[2.0] rm -rf tensorflow/contrib
PiperOrigin-RevId: 269816906
This commit is contained in:
parent
776b99925c
commit
ffc25308ce
@ -849,7 +849,7 @@ py_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = select({
|
||||
"api_version_2": [],
|
||||
"//conditions:default": ["//tensorflow/contrib:contrib_py"],
|
||||
"//conditions:default": [],
|
||||
}) + [
|
||||
":tensorflow_py_no_contrib",
|
||||
"//tensorflow/python/estimator:estimator_py",
|
||||
|
@ -1,175 +0,0 @@
|
||||
# Description:
|
||||
# contains parts of TensorFlow that are experimental or unstable and which are not supported.
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "if_not_windows")
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow:__subpackages__"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "contrib_py",
|
||||
srcs = glob(
|
||||
["**/*.py"],
|
||||
exclude = [
|
||||
"**/*_test.py",
|
||||
],
|
||||
),
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/all_reduce",
|
||||
"//tensorflow/contrib/batching:batch_py",
|
||||
"//tensorflow/contrib/bayesflow:bayesflow_py",
|
||||
"//tensorflow/contrib/boosted_trees:init_py",
|
||||
"//tensorflow/contrib/checkpoint/python:checkpoint",
|
||||
"//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
|
||||
"//tensorflow/contrib/compiler:compiler_py",
|
||||
"//tensorflow/contrib/compiler:xla",
|
||||
"//tensorflow/contrib/autograph",
|
||||
"//tensorflow/contrib/constrained_optimization",
|
||||
"//tensorflow/contrib/copy_graph:copy_graph_py",
|
||||
"//tensorflow/contrib/crf:crf_py",
|
||||
"//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py",
|
||||
"//tensorflow/contrib/data",
|
||||
"//tensorflow/contrib/deprecated:deprecated_py",
|
||||
"//tensorflow/contrib/distribute:distribute",
|
||||
"//tensorflow/contrib/distributions:distributions_py",
|
||||
"//tensorflow/contrib/eager/python:tfe",
|
||||
"//tensorflow/contrib/estimator:estimator_py",
|
||||
"//tensorflow/contrib/factorization:factorization_py",
|
||||
"//tensorflow/contrib/feature_column:feature_column_py",
|
||||
"//tensorflow/contrib/framework:framework_py",
|
||||
"//tensorflow/contrib/graph_editor:graph_editor_py",
|
||||
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
|
||||
"//tensorflow/contrib/hadoop",
|
||||
"//tensorflow/contrib/hooks",
|
||||
"//tensorflow/contrib/image:distort_image_py",
|
||||
"//tensorflow/contrib/image:image_py",
|
||||
"//tensorflow/contrib/image:single_image_random_dot_stereograms_py",
|
||||
"//tensorflow/contrib/input_pipeline:input_pipeline_py",
|
||||
"//tensorflow/contrib/integrate:integrate_py",
|
||||
"//tensorflow/contrib/keras",
|
||||
"//tensorflow/contrib/kernel_methods",
|
||||
"//tensorflow/contrib/labeled_tensor",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/contrib/learn:head_test_lib",
|
||||
"//tensorflow/contrib/legacy_seq2seq:seq2seq_py",
|
||||
"//tensorflow/contrib/libsvm",
|
||||
"//tensorflow/contrib/linear_optimizer:sdca_estimator_py",
|
||||
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
|
||||
"//tensorflow/contrib/lookup:lookup_py",
|
||||
"//tensorflow/contrib/losses:losses_py",
|
||||
"//tensorflow/contrib/losses:metric_learning_py",
|
||||
"//tensorflow/contrib/memory_stats:memory_stats_py",
|
||||
"//tensorflow/contrib/meta_graph_transform",
|
||||
"//tensorflow/contrib/metrics:metrics_py",
|
||||
"//tensorflow/contrib/mixed_precision:mixed_precision",
|
||||
"//tensorflow/contrib/model_pruning",
|
||||
"//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py",
|
||||
"//tensorflow/contrib/nn:nn_py",
|
||||
"//tensorflow/contrib/opt:opt_py",
|
||||
"//tensorflow/contrib/optimizer_v2:optimizer_v2_py",
|
||||
"//tensorflow/contrib/periodic_resample:init_py",
|
||||
"//tensorflow/contrib/predictor",
|
||||
"//tensorflow/contrib/proto",
|
||||
"//tensorflow/contrib/quantization:quantization_py",
|
||||
"//tensorflow/contrib/quantize:quantize_graph",
|
||||
"//tensorflow/contrib/receptive_field:receptive_field_py",
|
||||
"//tensorflow/contrib/recurrent:recurrent_py",
|
||||
"//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py",
|
||||
"//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py",
|
||||
"//tensorflow/contrib/resampler:resampler_py",
|
||||
"//tensorflow/contrib/rnn:rnn_py",
|
||||
"//tensorflow/contrib/rpc",
|
||||
"//tensorflow/contrib/saved_model:saved_model_py",
|
||||
"//tensorflow/contrib/seq2seq:seq2seq_py",
|
||||
"//tensorflow/contrib/signal:signal_py",
|
||||
"//tensorflow/contrib/slim",
|
||||
"//tensorflow/contrib/slim:nets",
|
||||
"//tensorflow/contrib/solvers:solvers_py",
|
||||
"//tensorflow/contrib/sparsemax:sparsemax_py",
|
||||
"//tensorflow/contrib/specs",
|
||||
"//tensorflow/contrib/staging",
|
||||
"//tensorflow/contrib/stat_summarizer:stat_summarizer_py",
|
||||
"//tensorflow/contrib/stateless",
|
||||
"//tensorflow/contrib/summary:summary",
|
||||
"//tensorflow/contrib/tensor_forest:init_py",
|
||||
"//tensorflow/contrib/tensorboard",
|
||||
"//tensorflow/contrib/testing:testing_py",
|
||||
"//tensorflow/contrib/text:text_py",
|
||||
"//tensorflow/contrib/tfprof",
|
||||
"//tensorflow/contrib/timeseries",
|
||||
"//tensorflow/contrib/tpu",
|
||||
"//tensorflow/contrib/training:training_py",
|
||||
"//tensorflow/contrib/util:util_py",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/estimator:estimator_py",
|
||||
] + select({
|
||||
"//tensorflow:android": [],
|
||||
"//tensorflow:ios": [],
|
||||
"//tensorflow:linux_s390x": [],
|
||||
"//tensorflow:windows": [],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/contrib/fused_conv:fused_conv_py",
|
||||
"//tensorflow/contrib/tensorrt:init_py",
|
||||
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
|
||||
],
|
||||
}) + select({
|
||||
"//tensorflow:android": [],
|
||||
"//tensorflow:ios": [],
|
||||
"//tensorflow:linux_s390x": [],
|
||||
"//tensorflow:windows": [],
|
||||
"//tensorflow:no_gcp_support": [],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/contrib/bigtable",
|
||||
"//tensorflow/contrib/cloud:cloud_py",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "contrib_kernels",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/boosted_trees:boosted_trees_kernels",
|
||||
"//tensorflow/contrib/factorization/kernels:all_kernels",
|
||||
"//tensorflow/contrib/hadoop:dataset_kernels",
|
||||
"//tensorflow/contrib/image:image_ops_kernels",
|
||||
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
|
||||
"//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
|
||||
"//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_kernels",
|
||||
"//tensorflow/contrib/rnn:all_kernels",
|
||||
"//tensorflow/contrib/seq2seq:beam_search_ops_kernels",
|
||||
"//tensorflow/contrib/tensor_forest:model_ops_kernels",
|
||||
"//tensorflow/contrib/tensor_forest:stats_ops_kernels",
|
||||
"//tensorflow/contrib/tensor_forest:tensor_forest_kernels",
|
||||
"//tensorflow/contrib/text:all_kernels",
|
||||
] + if_not_windows([
|
||||
"//tensorflow/contrib/tensorrt:trt_op_kernels",
|
||||
]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "contrib_ops_op_lib",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib",
|
||||
"//tensorflow/contrib/factorization:all_ops",
|
||||
"//tensorflow/contrib/framework:all_ops",
|
||||
"//tensorflow/contrib/hadoop:dataset_ops_op_lib",
|
||||
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib",
|
||||
"//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
|
||||
"//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_op_lib",
|
||||
"//tensorflow/contrib/rnn:all_ops",
|
||||
"//tensorflow/contrib/seq2seq:beam_search_ops_op_lib",
|
||||
"//tensorflow/contrib/tensor_forest:model_ops_op_lib",
|
||||
"//tensorflow/contrib/tensor_forest:stats_ops_op_lib",
|
||||
"//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
|
||||
"//tensorflow/contrib/text:all_ops",
|
||||
] + if_not_windows([
|
||||
"//tensorflow/compiler/tf2tensorrt:trt_op_libs",
|
||||
]),
|
||||
)
|
@ -1,23 +0,0 @@
|
||||
# TensorFlow contrib
|
||||
|
||||
Any code in this directory is not officially supported, and may change or be
|
||||
removed at any time without notice.
|
||||
|
||||
The contrib directory contains project directories, each of which has designated
|
||||
owners. It is meant to contain features and contributions that eventually should
|
||||
get merged into core TensorFlow, but whose interfaces may still change, or which
|
||||
require some testing to see whether they can find broader acceptance. We are
|
||||
trying to keep duplication within contrib to a minimum, so you may be asked to
|
||||
refactor code in contrib to use some feature inside core or in another project
|
||||
in contrib rather than reimplementing the feature.
|
||||
|
||||
When adding a project, please stick to the following directory structure:
|
||||
Create a project directory in `contrib/`, and mirror the portions of the
|
||||
TensorFlow tree that your project requires underneath `contrib/my_project/`.
|
||||
|
||||
For example, let's say you create foo ops in two files: `foo_ops.py` and
|
||||
`foo_ops_test.py`. If you were to merge those files directly into TensorFlow,
|
||||
they would live in `tensorflow/python/ops/foo_ops.py` and
|
||||
`tensorflow/python/kernel_tests/foo_ops_test.py`. In `contrib/`, they are part
|
||||
of project `foo`, and their full paths are `contrib/foo/python/ops/foo_ops.py`
|
||||
and `contrib/foo/python/kernel_tests/foo_ops_test.py`.
|
@ -1,122 +0,0 @@
|
||||
# pylint: disable=g-import-not-at-top
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Contrib module containing volatile or experimental code.
|
||||
|
||||
Warning: The `tf.contrib` module will not be included in TensorFlow 2.0. Many
|
||||
of its submodules have been integrated into TensorFlow core, or spun-off into
|
||||
other projects like [`tensorflow_io`](https://github.com/tensorflow/io), or
|
||||
[`tensorflow_addons`](https://github.com/tensorflow/addons). For instructions
|
||||
on how to upgrade see the
|
||||
[Migration guide](https://www.tensorflow.org/beta/guide/migration_guide).
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import platform
|
||||
|
||||
# Add projects here, they will show up under tf.contrib.
|
||||
from tensorflow.contrib import autograph
|
||||
from tensorflow.contrib import batching
|
||||
from tensorflow.contrib import bayesflow
|
||||
from tensorflow.contrib import checkpoint
|
||||
from tensorflow.contrib import cluster_resolver
|
||||
from tensorflow.contrib import compiler
|
||||
from tensorflow.contrib import constrained_optimization
|
||||
from tensorflow.contrib import copy_graph
|
||||
from tensorflow.contrib import crf
|
||||
from tensorflow.contrib import cudnn_rnn
|
||||
from tensorflow.contrib import data
|
||||
from tensorflow.contrib import deprecated
|
||||
from tensorflow.contrib import distribute
|
||||
from tensorflow.contrib import distributions
|
||||
from tensorflow.contrib import estimator
|
||||
from tensorflow.contrib import factorization
|
||||
from tensorflow.contrib import feature_column
|
||||
from tensorflow.contrib import framework
|
||||
from tensorflow.contrib import graph_editor
|
||||
from tensorflow.contrib import grid_rnn
|
||||
from tensorflow.contrib import image
|
||||
from tensorflow.contrib import input_pipeline
|
||||
from tensorflow.contrib import integrate
|
||||
from tensorflow.contrib import keras
|
||||
from tensorflow.contrib import kernel_methods
|
||||
from tensorflow.contrib import labeled_tensor
|
||||
from tensorflow.contrib import layers
|
||||
from tensorflow.contrib import learn
|
||||
from tensorflow.contrib import legacy_seq2seq
|
||||
from tensorflow.contrib import linear_optimizer
|
||||
from tensorflow.contrib import lookup
|
||||
from tensorflow.contrib import losses
|
||||
from tensorflow.contrib import memory_stats
|
||||
from tensorflow.contrib import metrics
|
||||
from tensorflow.contrib import mixed_precision
|
||||
from tensorflow.contrib import model_pruning
|
||||
from tensorflow.contrib import nn
|
||||
from tensorflow.contrib import opt
|
||||
from tensorflow.contrib import periodic_resample
|
||||
from tensorflow.contrib import predictor
|
||||
from tensorflow.contrib import proto
|
||||
from tensorflow.contrib import quantization
|
||||
from tensorflow.contrib import quantize
|
||||
from tensorflow.contrib import reduce_slice_ops
|
||||
from tensorflow.contrib import resampler
|
||||
from tensorflow.contrib import rnn
|
||||
from tensorflow.contrib import rpc
|
||||
from tensorflow.contrib import saved_model
|
||||
from tensorflow.contrib import seq2seq
|
||||
from tensorflow.contrib import signal
|
||||
from tensorflow.contrib import slim
|
||||
from tensorflow.contrib import solvers
|
||||
from tensorflow.contrib import sparsemax
|
||||
from tensorflow.contrib import staging
|
||||
from tensorflow.contrib import stat_summarizer
|
||||
from tensorflow.contrib import stateless
|
||||
from tensorflow.contrib import tensor_forest
|
||||
from tensorflow.contrib import tensorboard
|
||||
from tensorflow.contrib import testing
|
||||
from tensorflow.contrib import tfprof
|
||||
from tensorflow.contrib import timeseries
|
||||
from tensorflow.contrib import tpu
|
||||
from tensorflow.contrib import training
|
||||
from tensorflow.contrib import util
|
||||
from tensorflow.contrib.eager.python import tfe as eager
|
||||
from tensorflow.contrib.optimizer_v2 import optimizer_v2_symbols as optimizer_v2
|
||||
from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field
|
||||
from tensorflow.contrib.recurrent.python import recurrent_api as recurrent
|
||||
from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph
|
||||
from tensorflow.contrib.specs import python as specs
|
||||
from tensorflow.contrib.summary import summary
|
||||
|
||||
if os.name != "nt" and platform.machine() != "s390x":
|
||||
try:
|
||||
from tensorflow.contrib import cloud
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
ffmpeg = LazyLoader("ffmpeg", globals(),
|
||||
"tensorflow.contrib.ffmpeg")
|
||||
del os
|
||||
del platform
|
||||
|
||||
del LazyLoader
|
||||
|
||||
del absolute_import
|
||||
del division
|
||||
del print_function
|
@ -1,33 +0,0 @@
|
||||
# Description:
|
||||
# All-reduce implementations.
|
||||
# APIs are subject to change. Eventually to be replaced by equivalent
|
||||
# functionality within TensorFlow core.
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow:__subpackages__"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_library(
|
||||
name = "all_reduce_py",
|
||||
srcs = ["__init__.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":all_reduce",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "all_reduce",
|
||||
srcs = [
|
||||
"python/all_reduce.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/python/distribute:all_reduce",
|
||||
],
|
||||
)
|
@ -1,39 +0,0 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""All-reduce implementations."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,line-too-long,wildcard-import
|
||||
from tensorflow.contrib.all_reduce.python.all_reduce import *
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
# pylint: enable=unused-import,line-too-long,wildcard-import
|
||||
|
||||
_allowed_symbols = [
|
||||
'build_ring_all_reduce',
|
||||
'build_recursive_hd_all_reduce',
|
||||
'build_shuffle_all_reduce',
|
||||
'build_nccl_all_reduce',
|
||||
'build_nccl_then_ring',
|
||||
'build_nccl_then_recursive_hd',
|
||||
'build_nccl_then_shuffle',
|
||||
'build_shuffle_then_ring',
|
||||
'build_shuffle_then_shuffle'
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
|
@ -1,22 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utilities to construct a TF subgraph implementing distributed All-Reduce."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
from tensorflow.python.distribute.all_reduce import *
|
@ -1,87 +0,0 @@
|
||||
# Description:
|
||||
# JNI-based Java inference interface for TensorFlow.
|
||||
|
||||
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"if_android",
|
||||
"tf_copts",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files([
|
||||
"LICENSE",
|
||||
"jni/version_script.lds",
|
||||
])
|
||||
|
||||
filegroup(
|
||||
name = "android_tensorflow_inference_jni_srcs",
|
||||
srcs = glob([
|
||||
"**/*.cc",
|
||||
"**/*.h",
|
||||
]),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "android_tensorflow_inference_jni",
|
||||
srcs = if_android([":android_tensorflow_inference_jni_srcs"]),
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
"//tensorflow/java/src/main/native",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# JAR with Java bindings to TF.
|
||||
android_library(
|
||||
name = "android_tensorflow_inference_java",
|
||||
srcs = glob(["java/**/*.java"]) + ["//tensorflow/java:java_sources"],
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
],
|
||||
)
|
||||
|
||||
# Build the native .so.
|
||||
# bazel build //tensorflow/contrib/android:libtensorflow_inference.so \
|
||||
# --crosstool_top=//external:android/crosstool \
|
||||
# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
|
||||
# --cpu=armeabi-v7a
|
||||
LINKER_SCRIPT = "//tensorflow/contrib/android:jni/version_script.lds"
|
||||
|
||||
cc_binary(
|
||||
name = "libtensorflow_inference.so",
|
||||
copts = tf_copts() + [
|
||||
"-ffunction-sections",
|
||||
"-fdata-sections",
|
||||
],
|
||||
linkopts = if_android([
|
||||
"-landroid",
|
||||
"-latomic",
|
||||
"-ldl",
|
||||
"-llog",
|
||||
"-lm",
|
||||
"-z defs",
|
||||
"-s",
|
||||
"-Wl,--gc-sections",
|
||||
"-Wl,--version-script,$(location {})".format(LINKER_SCRIPT),
|
||||
]),
|
||||
linkshared = 1,
|
||||
linkstatic = 1,
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
],
|
||||
deps = [
|
||||
":android_tensorflow_inference_jni",
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
LINKER_SCRIPT,
|
||||
],
|
||||
)
|
@ -1,95 +0,0 @@
|
||||
# Android TensorFlow support
|
||||
|
||||
This directory defines components (a native `.so` library and a Java JAR)
|
||||
geared towards supporting TensorFlow on Android. This includes:
|
||||
|
||||
- The [TensorFlow Java API](../../java/README.md)
|
||||
- A `TensorFlowInferenceInterface` class that provides a smaller API
|
||||
surface suitable for inference and summarizing performance of model execution.
|
||||
|
||||
For example usage, see [TensorFlowImageClassifier.java](../../examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java)
|
||||
in the [TensorFlow Android Demo](../../examples/android).
|
||||
|
||||
For prebuilt libraries, see the
|
||||
[nightly Android build artifacts](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)
|
||||
page for a recent build.
|
||||
|
||||
The TensorFlow Inference Interface is also available as a
|
||||
[JCenter package](https://bintray.com/google/tensorflow/tensorflow)
|
||||
(see the tensorflow-android directory) and can be included quite simply in your
|
||||
android project with a couple of lines in the project's `build.gradle` file:
|
||||
|
||||
```
|
||||
allprojects {
|
||||
repositories {
|
||||
jcenter()
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
compile 'org.tensorflow:tensorflow-android:+'
|
||||
}
|
||||
```
|
||||
|
||||
This will tell Gradle to use the
|
||||
[latest version](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
|
||||
of the TensorFlow AAR that has been released to
|
||||
[JCenter](https://jcenter.bintray.com/org/tensorflow/tensorflow-android/).
|
||||
You may replace the `+` with an explicit version label if you wish to
|
||||
use a specific release of TensorFlow in your app.
|
||||
|
||||
To build the libraries yourself (if, for example, you want to support custom
|
||||
TensorFlow operators), pick your preferred approach below:
|
||||
|
||||
### Bazel
|
||||
|
||||
First follow the Bazel setup instructions described in
|
||||
[tensorflow/examples/android/README.md](../../examples/android/README.md)
|
||||
|
||||
Then, to build the native TF library:
|
||||
|
||||
```
|
||||
bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \
|
||||
--crosstool_top=//external:android/crosstool \
|
||||
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
|
||||
--cxxopt=-std=c++11 \
|
||||
--cpu=armeabi-v7a
|
||||
```
|
||||
|
||||
Replacing `armeabi-v7a` with your desired target architecture.
|
||||
|
||||
The library will be located at:
|
||||
|
||||
```
|
||||
bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so
|
||||
```
|
||||
|
||||
To build the Java counterpart:
|
||||
|
||||
```
|
||||
bazel build //tensorflow/contrib/android:android_tensorflow_inference_java
|
||||
```
|
||||
|
||||
You will find the JAR file at:
|
||||
|
||||
```
|
||||
bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar
|
||||
```
|
||||
|
||||
### CMake
|
||||
|
||||
For documentation on building a self-contained AAR file with cmake, see
|
||||
[tensorflow/contrib/android/cmake](cmake).
|
||||
|
||||
|
||||
### Makefile
|
||||
|
||||
For documentation on building native TF libraries with make, including a CUDA-enabled variant for devices like the Nvidia Shield TV, see [tensorflow/contrib/makefile/README.md](../makefile/README.md)
|
||||
|
||||
|
||||
## AssetManagerFileSystem
|
||||
|
||||
This directory also contains a TensorFlow filesystem supporting the Android
|
||||
asset manager. This may be useful when writing native (C++) code that is tightly
|
||||
coupled with TensorFlow. For typical usage, the library above will be
|
||||
sufficient.
|
@ -1,272 +0,0 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/android/asset_manager_filesystem.h"
|
||||
|
||||
#include <unistd.h>
|
||||
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/file_system_helper.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
string RemoveSuffix(const string& name, const string& suffix) {
|
||||
string output(name);
|
||||
StringPiece piece(output);
|
||||
absl::ConsumeSuffix(&piece, suffix);
|
||||
return string(piece);
|
||||
}
|
||||
|
||||
// Closes the given AAsset when variable is destructed.
|
||||
class ScopedAsset {
|
||||
public:
|
||||
ScopedAsset(AAsset* asset) : asset_(asset) {}
|
||||
~ScopedAsset() {
|
||||
if (asset_ != nullptr) {
|
||||
AAsset_close(asset_);
|
||||
}
|
||||
}
|
||||
|
||||
AAsset* get() const { return asset_; }
|
||||
|
||||
private:
|
||||
AAsset* asset_;
|
||||
};
|
||||
|
||||
// Closes the given AAssetDir when variable is destructed.
|
||||
class ScopedAssetDir {
|
||||
public:
|
||||
ScopedAssetDir(AAssetDir* asset_dir) : asset_dir_(asset_dir) {}
|
||||
~ScopedAssetDir() {
|
||||
if (asset_dir_ != nullptr) {
|
||||
AAssetDir_close(asset_dir_);
|
||||
}
|
||||
}
|
||||
|
||||
AAssetDir* get() const { return asset_dir_; }
|
||||
|
||||
private:
|
||||
AAssetDir* asset_dir_;
|
||||
};
|
||||
|
||||
class ReadOnlyMemoryRegionFromAsset : public ReadOnlyMemoryRegion {
|
||||
public:
|
||||
ReadOnlyMemoryRegionFromAsset(std::unique_ptr<char[]> data, uint64 length)
|
||||
: data_(std::move(data)), length_(length) {}
|
||||
~ReadOnlyMemoryRegionFromAsset() override = default;
|
||||
|
||||
const void* data() override { return reinterpret_cast<void*>(data_.get()); }
|
||||
uint64 length() override { return length_; }
|
||||
|
||||
private:
|
||||
std::unique_ptr<char[]> data_;
|
||||
uint64 length_;
|
||||
};
|
||||
|
||||
// Note that AAssets are not thread-safe and cannot be used across threads.
|
||||
// However, AAssetManager is. Because RandomAccessFile must be thread-safe and
|
||||
// used across threads, new AAssets must be created for every access.
|
||||
// TODO(tylerrhodes): is there a more efficient way to do this?
|
||||
class RandomAccessFileFromAsset : public RandomAccessFile {
|
||||
public:
|
||||
RandomAccessFileFromAsset(AAssetManager* asset_manager, const string& name)
|
||||
: asset_manager_(asset_manager), file_name_(name) {}
|
||||
~RandomAccessFileFromAsset() override = default;
|
||||
|
||||
Status Read(uint64 offset, size_t to_read, StringPiece* result,
|
||||
char* scratch) const override {
|
||||
auto asset = ScopedAsset(AAssetManager_open(
|
||||
asset_manager_, file_name_.c_str(), AASSET_MODE_RANDOM));
|
||||
if (asset.get() == nullptr) {
|
||||
return errors::NotFound("File ", file_name_, " not found.");
|
||||
}
|
||||
|
||||
off64_t new_offset = AAsset_seek64(asset.get(), offset, SEEK_SET);
|
||||
off64_t length = AAsset_getLength64(asset.get());
|
||||
if (new_offset < 0) {
|
||||
*result = StringPiece(scratch, 0);
|
||||
return errors::OutOfRange("Read after file end.");
|
||||
}
|
||||
const off64_t region_left =
|
||||
std::min(length - new_offset, static_cast<off64_t>(to_read));
|
||||
int read = AAsset_read(asset.get(), scratch, region_left);
|
||||
if (read < 0) {
|
||||
return errors::Internal("Error reading from asset.");
|
||||
}
|
||||
*result = StringPiece(scratch, region_left);
|
||||
return (region_left == to_read)
|
||||
? Status::OK()
|
||||
: errors::OutOfRange("Read less bytes than requested.");
|
||||
}
|
||||
|
||||
private:
|
||||
AAssetManager* asset_manager_;
|
||||
string file_name_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
AssetManagerFileSystem::AssetManagerFileSystem(AAssetManager* asset_manager,
|
||||
const string& prefix)
|
||||
: asset_manager_(asset_manager), prefix_(prefix) {}
|
||||
|
||||
Status AssetManagerFileSystem::FileExists(const string& fname) {
|
||||
string path = RemoveAssetPrefix(fname);
|
||||
auto asset = ScopedAsset(
|
||||
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM));
|
||||
if (asset.get() == nullptr) {
|
||||
return errors::NotFound("File ", fname, " not found.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AssetManagerFileSystem::NewRandomAccessFile(
|
||||
const string& fname, std::unique_ptr<RandomAccessFile>* result) {
|
||||
string path = RemoveAssetPrefix(fname);
|
||||
auto asset = ScopedAsset(
|
||||
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM));
|
||||
if (asset.get() == nullptr) {
|
||||
return errors::NotFound("File ", fname, " not found.");
|
||||
}
|
||||
result->reset(new RandomAccessFileFromAsset(asset_manager_, path));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AssetManagerFileSystem::NewReadOnlyMemoryRegionFromFile(
|
||||
const string& fname, std::unique_ptr<ReadOnlyMemoryRegion>* result) {
|
||||
string path = RemoveAssetPrefix(fname);
|
||||
auto asset = ScopedAsset(
|
||||
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_STREAMING));
|
||||
if (asset.get() == nullptr) {
|
||||
return errors::NotFound("File ", fname, " not found.");
|
||||
}
|
||||
|
||||
off64_t start, length;
|
||||
int fd = AAsset_openFileDescriptor64(asset.get(), &start, &length);
|
||||
std::unique_ptr<char[]> data;
|
||||
if (fd >= 0) {
|
||||
data.reset(new char[length]);
|
||||
ssize_t result = pread(fd, data.get(), length, start);
|
||||
if (result < 0) {
|
||||
return errors::Internal("Error reading from file ", fname,
|
||||
" using 'read': ", result);
|
||||
}
|
||||
if (result != length) {
|
||||
return errors::Internal("Expected size does not match size read: ",
|
||||
"Expected ", length, " vs. read ", result);
|
||||
}
|
||||
close(fd);
|
||||
} else {
|
||||
length = AAsset_getLength64(asset.get());
|
||||
data.reset(new char[length]);
|
||||
const void* asset_buffer = AAsset_getBuffer(asset.get());
|
||||
if (asset_buffer == nullptr) {
|
||||
return errors::Internal("Error reading ", fname, " from asset manager.");
|
||||
}
|
||||
memcpy(data.get(), asset_buffer, length);
|
||||
}
|
||||
result->reset(new ReadOnlyMemoryRegionFromAsset(std::move(data), length));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AssetManagerFileSystem::GetChildren(const string& prefixed_dir,
|
||||
std::vector<string>* r) {
|
||||
std::string path = NormalizeDirectoryPath(prefixed_dir);
|
||||
auto dir =
|
||||
ScopedAssetDir(AAssetManager_openDir(asset_manager_, path.c_str()));
|
||||
if (dir.get() == nullptr) {
|
||||
return errors::NotFound("Directory ", prefixed_dir, " not found.");
|
||||
}
|
||||
const char* next_file = AAssetDir_getNextFileName(dir.get());
|
||||
while (next_file != nullptr) {
|
||||
r->push_back(next_file);
|
||||
next_file = AAssetDir_getNextFileName(dir.get());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AssetManagerFileSystem::GetFileSize(const string& fname, uint64* s) {
|
||||
// If fname corresponds to a directory, return early. It doesn't map to an
|
||||
// AAsset, and would otherwise return NotFound.
|
||||
if (DirectoryExists(fname)) {
|
||||
*s = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
string path = RemoveAssetPrefix(fname);
|
||||
auto asset = ScopedAsset(
|
||||
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM));
|
||||
if (asset.get() == nullptr) {
|
||||
return errors::NotFound("File ", fname, " not found.");
|
||||
}
|
||||
*s = AAsset_getLength64(asset.get());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AssetManagerFileSystem::Stat(const string& fname, FileStatistics* stat) {
|
||||
uint64 size;
|
||||
stat->is_directory = DirectoryExists(fname);
|
||||
TF_RETURN_IF_ERROR(GetFileSize(fname, &size));
|
||||
stat->length = size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
string AssetManagerFileSystem::NormalizeDirectoryPath(const string& fname) {
|
||||
return RemoveSuffix(RemoveAssetPrefix(fname), "/");
|
||||
}
|
||||
|
||||
string AssetManagerFileSystem::RemoveAssetPrefix(const string& name) {
|
||||
StringPiece piece(name);
|
||||
absl::ConsumePrefix(&piece, prefix_);
|
||||
return string(piece);
|
||||
}
|
||||
|
||||
bool AssetManagerFileSystem::DirectoryExists(const std::string& fname) {
|
||||
std::string path = NormalizeDirectoryPath(fname);
|
||||
auto dir =
|
||||
ScopedAssetDir(AAssetManager_openDir(asset_manager_, path.c_str()));
|
||||
// Note that openDir will return something even if the directory doesn't
|
||||
// exist. Therefore, we need to ensure one file exists in the folder.
|
||||
return AAssetDir_getNextFileName(dir.get()) != NULL;
|
||||
}
|
||||
|
||||
Status AssetManagerFileSystem::GetMatchingPaths(const string& pattern,
|
||||
std::vector<string>* results) {
|
||||
return internal::GetMatchingPaths(this, Env::Default(), pattern, results);
|
||||
}
|
||||
|
||||
Status AssetManagerFileSystem::NewWritableFile(
|
||||
const string& fname, std::unique_ptr<WritableFile>* result) {
|
||||
return errors::Unimplemented("Asset storage is read only.");
|
||||
}
|
||||
Status AssetManagerFileSystem::NewAppendableFile(
|
||||
const string& fname, std::unique_ptr<WritableFile>* result) {
|
||||
return errors::Unimplemented("Asset storage is read only.");
|
||||
}
|
||||
Status AssetManagerFileSystem::DeleteFile(const string& f) {
|
||||
return errors::Unimplemented("Asset storage is read only.");
|
||||
}
|
||||
Status AssetManagerFileSystem::CreateDir(const string& d) {
|
||||
return errors::Unimplemented("Asset storage is read only.");
|
||||
}
|
||||
Status AssetManagerFileSystem::DeleteDir(const string& d) {
|
||||
return errors::Unimplemented("Asset storage is read only.");
|
||||
}
|
||||
Status AssetManagerFileSystem::RenameFile(const string& s, const string& t) {
|
||||
return errors::Unimplemented("Asset storage is read only.");
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -1,85 +0,0 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
|
||||
#define TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
|
||||
|
||||
#include <android/asset_manager.h>
|
||||
#include <android/asset_manager_jni.h>
|
||||
|
||||
#include "tensorflow/core/platform/file_system.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// FileSystem that uses Android's AAssetManager. Once initialized with a given
|
||||
// AAssetManager, files in the given AAssetManager can be accessed through the
|
||||
// prefix given when registered with the TensorFlow Env.
|
||||
// Note that because APK assets are immutable, any operation that tries to
|
||||
// modify the FileSystem will return tensorflow::error::code::UNIMPLEMENTED.
|
||||
class AssetManagerFileSystem : public FileSystem {
|
||||
public:
|
||||
// Initialize an AssetManagerFileSystem. Note that this does not register the
|
||||
// file system with TensorFlow.
|
||||
// asset_manager - Non-null Android AAssetManager that backs this file
|
||||
// system. The asset manager is not owned by this file system, and must
|
||||
// outlive this class.
|
||||
// prefix - Common prefix to strip from all file URIs before passing them to
|
||||
// the asset_manager. This is required because TensorFlow gives the entire
|
||||
// file URI (file:///my_dir/my_file.txt) and AssetManager only knows paths
|
||||
// relative to its base directory.
|
||||
AssetManagerFileSystem(AAssetManager* asset_manager, const string& prefix);
|
||||
~AssetManagerFileSystem() override = default;
|
||||
|
||||
Status FileExists(const string& fname) override;
|
||||
Status NewRandomAccessFile(
|
||||
const string& filename,
|
||||
std::unique_ptr<RandomAccessFile>* result) override;
|
||||
Status NewReadOnlyMemoryRegionFromFile(
|
||||
const string& filename,
|
||||
std::unique_ptr<ReadOnlyMemoryRegion>* result) override;
|
||||
|
||||
Status GetFileSize(const string& f, uint64* s) override;
|
||||
// Currently just returns size.
|
||||
Status Stat(const string& fname, FileStatistics* stat) override;
|
||||
Status GetChildren(const string& dir, std::vector<string>* r) override;
|
||||
|
||||
// All these functions return Unimplemented error. Asset storage is
|
||||
// read only.
|
||||
Status NewWritableFile(const string& fname,
|
||||
std::unique_ptr<WritableFile>* result) override;
|
||||
Status NewAppendableFile(const string& fname,
|
||||
std::unique_ptr<WritableFile>* result) override;
|
||||
Status DeleteFile(const string& f) override;
|
||||
Status CreateDir(const string& d) override;
|
||||
Status DeleteDir(const string& d) override;
|
||||
Status RenameFile(const string& s, const string& t) override;
|
||||
|
||||
Status GetMatchingPaths(const string& pattern,
|
||||
std::vector<string>* results) override;
|
||||
|
||||
private:
|
||||
string RemoveAssetPrefix(const string& name);
|
||||
|
||||
// Return a string path that can be passed into AAssetManager functions.
|
||||
// For example, 'my_prefix://some/dir/' would return 'some/dir'.
|
||||
string NormalizeDirectoryPath(const string& fname);
|
||||
bool DirectoryExists(const std::string& fname);
|
||||
|
||||
AAssetManager* asset_manager_;
|
||||
string prefix_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
|
@ -1,80 +0,0 @@
|
||||
#
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
cmake_minimum_required(VERSION 3.4.1)
|
||||
include(ExternalProject)
|
||||
|
||||
# TENSORFLOW_ROOT_DIR:
|
||||
# root directory of tensorflow repo
|
||||
# used for shared source files and pre-built libs
|
||||
get_filename_component(TENSORFLOW_ROOT_DIR ../../../.. ABSOLUTE)
|
||||
set(PREBUILT_DIR ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen)
|
||||
|
||||
add_library(lib_proto STATIC IMPORTED )
|
||||
set_target_properties(lib_proto PROPERTIES IMPORTED_LOCATION
|
||||
${PREBUILT_DIR}/protobuf/lib/libprotobuf.a)
|
||||
|
||||
add_library(lib_nsync STATIC IMPORTED )
|
||||
set_target_properties(lib_nsync PROPERTIES IMPORTED_LOCATION
|
||||
${TARGET_NSYNC_LIB}/lib/libnsync.a)
|
||||
|
||||
add_library(lib_tf STATIC IMPORTED )
|
||||
set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION
|
||||
${PREBUILT_DIR}/lib/libtensorflow-core.a)
|
||||
# Change to compile flags should be replicated into bazel build file
|
||||
# TODO: Consider options other than -O2 for binary size.
|
||||
# e.g. -Os for gcc, and -Oz for clang.
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \
|
||||
-std=c++11 -fno-rtti -fno-exceptions \
|
||||
-O2 -Wno-narrowing -fomit-frame-pointer \
|
||||
-mfpu=neon -mfloat-abi=softfp -fPIE -fPIC \
|
||||
-ftemplate-depth=900 \
|
||||
-DGOOGLE_PROTOBUF_NO_RTTI \
|
||||
-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER")
|
||||
|
||||
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} \
|
||||
-Wl,--allow-multiple-definition \
|
||||
-Wl,--whole-archive \
|
||||
-fPIE -pie -v")
|
||||
file(GLOB tensorflow_inference_sources
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../jni/*.cc)
|
||||
file(GLOB java_api_native_sources
|
||||
${TENSORFLOW_ROOT_DIR}/tensorflow/java/src/main/native/*.cc)
|
||||
|
||||
add_library(tensorflow_inference SHARED
|
||||
${tensorflow_inference_sources}
|
||||
${TENSORFLOW_ROOT_DIR}/tensorflow/c/tf_status_helper.cc
|
||||
${TENSORFLOW_ROOT_DIR}/tensorflow/c/checkpoint_reader.cc
|
||||
${TENSORFLOW_ROOT_DIR}/tensorflow/c/test_op.cc
|
||||
${TENSORFLOW_ROOT_DIR}/tensorflow/c/c_api.cc
|
||||
${java_api_native_sources})
|
||||
|
||||
# Include libraries needed for hello-jni lib
|
||||
target_link_libraries(tensorflow_inference
|
||||
android
|
||||
dl
|
||||
log
|
||||
m
|
||||
z
|
||||
lib_tf
|
||||
lib_proto
|
||||
lib_nsync)
|
||||
|
||||
include_directories(
|
||||
${PREBUILT_DIR}/proto
|
||||
${PREBUILT_DIR}/protobuf/include
|
||||
${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/downloads/eigen
|
||||
${TENSORFLOW_ROOT_DIR}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/..)
|
@ -1,48 +0,0 @@
|
||||
TensorFlow-Android-Inference
|
||||
============================
|
||||
This directory contains CMake support for building the Android Java Inference
|
||||
interface to the TensorFlow native APIs.
|
||||
|
||||
See [tensorflow/contrib/android](..) for more details about the library, and
|
||||
instructions for building with Bazel.
|
||||
|
||||
Usage
|
||||
-----
|
||||
Add TensorFlow-Android-Inference as a dependency of your Android application
|
||||
|
||||
* settings.gradle
|
||||
|
||||
```
|
||||
include ':TensorFlow-Android-Inference'
|
||||
findProject(":TensorFlow-Android-Inference").projectDir =
|
||||
new File("${/path/to/tensorflow_repo}/contrib/android/cmake")
|
||||
```
|
||||
|
||||
* application's build.gradle (adding dependency):
|
||||
|
||||
```
|
||||
debugCompile project(path: ':tensorflow_inference', configuration: 'debug')
|
||||
releaseCompile project(path: ':tensorflow_inference', configuration: 'release')
|
||||
```
|
||||
Note: this makes native code in the lib traceable from your app.
|
||||
|
||||
Dependencies
|
||||
------------
|
||||
TensorFlow-Android-Inference depends on the TensorFlow static libs already built
|
||||
in your local TensorFlow repo directory. For Linux/Mac OS, build_all_android.sh
|
||||
is used in build.gradle to build it. It DOES take time to build the core libs;
|
||||
so, by default, it is commented out to avoid confusion (otherwise
|
||||
Android Studio would appear to hang during opening the project).
|
||||
To enable it, refer to the comment in
|
||||
|
||||
* build.gradle
|
||||
|
||||
Output
|
||||
------
|
||||
- TensorFlow-Inference-debug.aar
|
||||
- TensorFlow-Inference-release.aar
|
||||
|
||||
File libtensorflow_inference.so should be packed under jni/${ANDROID_ABI}/
|
||||
in the above aar, and it is transparent to the app as it will access them via
|
||||
equivalent java APIs.
|
||||
|
@ -1,105 +0,0 @@
|
||||
apply plugin: 'com.android.library'
|
||||
|
||||
// TensorFlow repo root dir on local machine
|
||||
def TF_SRC_DIR = projectDir.toString() + "/../../../.."
|
||||
|
||||
android {
|
||||
compileSdkVersion 24
|
||||
// Check local build_tools_version as this is liable to change within Android Studio.
|
||||
buildToolsVersion '25.0.2'
|
||||
|
||||
// for debugging native code purpose
|
||||
publishNonDefault true
|
||||
|
||||
defaultConfig {
|
||||
archivesBaseName = "Tensorflow-Android-Inference"
|
||||
minSdkVersion 21
|
||||
targetSdkVersion 23
|
||||
versionCode 1
|
||||
versionName "1.0"
|
||||
ndk {
|
||||
abiFilters 'armeabi-v7a'
|
||||
}
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
arguments '-DANDROID_TOOLCHAIN=clang',
|
||||
'-DANDROID_STL=c++_static'
|
||||
}
|
||||
}
|
||||
}
|
||||
sourceSets {
|
||||
main {
|
||||
java {
|
||||
srcDir "${TF_SRC_DIR}/tensorflow/contrib/android/java"
|
||||
srcDir "${TF_SRC_DIR}/tensorflow/java/src/main/java"
|
||||
exclude '**/examples/**'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
path 'CMakeLists.txt'
|
||||
}
|
||||
}
|
||||
buildTypes {
|
||||
release {
|
||||
minifyEnabled false
|
||||
proguardFiles getDefaultProguardFile('proguard-android.txt'),
|
||||
'proguard-rules.pro'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build libtensorflow-core.a if necessary
|
||||
// Note: the environment needs to be set up already
|
||||
// [ such as installing autoconfig, make, etc ]
|
||||
// How to use:
|
||||
// 1) install all of the necessary tools to build libtensorflow-core.a
|
||||
// 2) inside Android Studio IDE, uncomment buildTensorFlow in
|
||||
// whenTaskAdded{...}
|
||||
// 3) re-sync and re-build. It could take a long time if NOT building
|
||||
// with multiple processes.
|
||||
import org.apache.tools.ant.taskdefs.condition.Os
|
||||
|
||||
Properties properties = new Properties()
|
||||
properties.load(project.rootProject.file('local.properties')
|
||||
.newDataInputStream())
|
||||
def ndkDir = properties.getProperty('ndk.dir')
|
||||
if (ndkDir == null || ndkDir == "") {
|
||||
ndkDir = System.getenv('ANDROID_NDK_HOME')
|
||||
}
|
||||
|
||||
if (!Os.isFamily(Os.FAMILY_WINDOWS)) {
|
||||
// This script is for non-Windows OS. For Windows OS, MANUALLY build
|
||||
// (or copy the built) libs/headers to the
|
||||
// ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen
|
||||
// refer to CMakeLists.txt about lib and header directories for details
|
||||
task buildTensorflow(type: Exec) {
|
||||
group 'buildTensorflowLib'
|
||||
workingDir getProjectDir().toString() + '/../../../../'
|
||||
environment PATH: '/opt/local/bin:/opt/local/sbin:' +
|
||||
System.getenv('PATH')
|
||||
environment NDK_ROOT: ndkDir
|
||||
commandLine 'tensorflow/contrib/makefile/build_all_android.sh'
|
||||
}
|
||||
|
||||
tasks.whenTaskAdded { task ->
|
||||
group 'buildTensorflowLib'
|
||||
if (task.name.toLowerCase().contains('sources')) {
|
||||
def tensorflowTarget = new File(getProjectDir().toString() +
|
||||
'/../../makefile/gen/lib/libtensorflow-core.a')
|
||||
if (!tensorflowTarget.exists()) {
|
||||
// Note:
|
||||
// just uncomment this line to use it:
|
||||
// it can take long time to build by default
|
||||
// it is disabled to avoid false first impression
|
||||
task.dependsOn buildTensorflow
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
compile fileTree(dir: 'libs', include: ['*.jar'])
|
||||
}
|
@ -1,13 +0,0 @@
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="org.tensorflow.contrib.android">
|
||||
|
||||
<uses-sdk
|
||||
android:minSdkVersion="4"
|
||||
android:targetSdkVersion="19" />
|
||||
|
||||
<application android:allowBackup="true" android:label="@string/app_name"
|
||||
android:supportsRtl="true">
|
||||
|
||||
</application>
|
||||
|
||||
</manifest>
|
@ -1,3 +0,0 @@
|
||||
<resources>
|
||||
<string name="app_name">TensorFlowInference</string>
|
||||
</resources>
|
@ -1,63 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.contrib.android;
|
||||
|
||||
/** Accumulate and analyze stats from metadata obtained from Session.Runner.run. */
|
||||
public class RunStats implements AutoCloseable {
|
||||
|
||||
/**
|
||||
* Options to be provided to a {@link org.tensorflow.Session.Runner} to enable stats accumulation.
|
||||
*/
|
||||
public static byte[] runOptions() {
|
||||
return fullTraceRunOptions;
|
||||
}
|
||||
|
||||
public RunStats() {
|
||||
nativeHandle = allocate();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (nativeHandle != 0) {
|
||||
delete(nativeHandle);
|
||||
}
|
||||
nativeHandle = 0;
|
||||
}
|
||||
|
||||
/** Accumulate stats obtained when executing a graph. */
|
||||
public synchronized void add(byte[] runMetadata) {
|
||||
add(nativeHandle, runMetadata);
|
||||
}
|
||||
|
||||
/** Summary of the accumulated runtime stats. */
|
||||
public synchronized String summary() {
|
||||
return summary(nativeHandle);
|
||||
}
|
||||
|
||||
private long nativeHandle;
|
||||
|
||||
// Hack: This is what a serialized RunOptions protocol buffer with trace_level: FULL_TRACE ends
|
||||
// up as.
|
||||
private static byte[] fullTraceRunOptions = new byte[] {0x08, 0x03};
|
||||
|
||||
private static native long allocate();
|
||||
|
||||
private static native void delete(long handle);
|
||||
|
||||
private static native void add(long handle, byte[] runMetadata);
|
||||
|
||||
private static native String summary(long handle);
|
||||
}
|
@ -1,650 +0,0 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.contrib.android;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
import android.os.Build.VERSION;
|
||||
import android.os.Trace;
|
||||
import android.text.TextUtils;
|
||||
import android.util.Log;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.DoubleBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.nio.IntBuffer;
|
||||
import java.nio.LongBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.tensorflow.Graph;
|
||||
import org.tensorflow.Operation;
|
||||
import org.tensorflow.Session;
|
||||
import org.tensorflow.Tensor;
|
||||
import org.tensorflow.TensorFlow;
|
||||
import org.tensorflow.Tensors;
|
||||
import org.tensorflow.types.UInt8;
|
||||
|
||||
/**
|
||||
* Wrapper over the TensorFlow API ({@link Graph}, {@link Session}) providing a smaller API surface
|
||||
* for inference.
|
||||
*
|
||||
* <p>See tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java for an
|
||||
* example usage.
|
||||
*/
|
||||
public class TensorFlowInferenceInterface {
|
||||
private static final String TAG = "TensorFlowInferenceInterface";
|
||||
private static final String ASSET_FILE_PREFIX = "file:///android_asset/";
|
||||
|
||||
/*
|
||||
* Load a TensorFlow model from the AssetManager or from disk if it is not an asset file.
|
||||
*
|
||||
* @param assetManager The AssetManager to use to load the model file.
|
||||
* @param model The filepath to the GraphDef proto representing the model.
|
||||
*/
|
||||
public TensorFlowInferenceInterface(AssetManager assetManager, String model) {
|
||||
prepareNativeRuntime();
|
||||
|
||||
this.modelName = model;
|
||||
this.g = new Graph();
|
||||
this.sess = new Session(g);
|
||||
this.runner = sess.runner();
|
||||
|
||||
final boolean hasAssetPrefix = model.startsWith(ASSET_FILE_PREFIX);
|
||||
InputStream is = null;
|
||||
try {
|
||||
String aname = hasAssetPrefix ? model.split(ASSET_FILE_PREFIX)[1] : model;
|
||||
is = assetManager.open(aname);
|
||||
} catch (IOException e) {
|
||||
if (hasAssetPrefix) {
|
||||
throw new RuntimeException("Failed to load model from '" + model + "'", e);
|
||||
}
|
||||
// Perhaps the model file is not an asset but is on disk.
|
||||
try {
|
||||
is = new FileInputStream(model);
|
||||
} catch (IOException e2) {
|
||||
throw new RuntimeException("Failed to load model from '" + model + "'", e);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.beginSection("initializeTensorFlow");
|
||||
Trace.beginSection("readGraphDef");
|
||||
}
|
||||
|
||||
// TODO(ashankar): Can we somehow mmap the contents instead of copying them?
|
||||
byte[] graphDef = new byte[is.available()];
|
||||
final int numBytesRead = is.read(graphDef);
|
||||
if (numBytesRead != graphDef.length) {
|
||||
throw new IOException(
|
||||
"read error: read only "
|
||||
+ numBytesRead
|
||||
+ " of the graph, expected to read "
|
||||
+ graphDef.length);
|
||||
}
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.endSection(); // readGraphDef.
|
||||
}
|
||||
|
||||
loadGraph(graphDef, g);
|
||||
is.close();
|
||||
Log.i(TAG, "Successfully loaded model from '" + model + "'");
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.endSection(); // initializeTensorFlow.
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to load model from '" + model + "'", e);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Load a TensorFlow model from provided InputStream.
|
||||
* Note: The InputStream will not be closed after loading model, users need to
|
||||
* close it themselves.
|
||||
*
|
||||
* @param is The InputStream to use to load the model.
|
||||
*/
|
||||
public TensorFlowInferenceInterface(InputStream is) {
|
||||
prepareNativeRuntime();
|
||||
|
||||
// modelName is redundant for model loading from input stream, here is for
|
||||
// avoiding error in initialization as modelName is marked final.
|
||||
this.modelName = "";
|
||||
this.g = new Graph();
|
||||
this.sess = new Session(g);
|
||||
this.runner = sess.runner();
|
||||
|
||||
try {
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.beginSection("initializeTensorFlow");
|
||||
Trace.beginSection("readGraphDef");
|
||||
}
|
||||
|
||||
int baosInitSize = is.available() > 16384 ? is.available() : 16384;
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize);
|
||||
int numBytesRead;
|
||||
byte[] buf = new byte[16384];
|
||||
while ((numBytesRead = is.read(buf, 0, buf.length)) != -1) {
|
||||
baos.write(buf, 0, numBytesRead);
|
||||
}
|
||||
byte[] graphDef = baos.toByteArray();
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.endSection(); // readGraphDef.
|
||||
}
|
||||
|
||||
loadGraph(graphDef, g);
|
||||
Log.i(TAG, "Successfully loaded model from the input stream");
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.endSection(); // initializeTensorFlow.
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to load model from the input stream", e);
|
||||
}
|
||||
}
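Because this constructor leaves the stream open, callers typically manage it with try-with-resources. A minimal sketch, not part of the original file; the on-disk path "frozen_graph.pb" is an assumption for illustration:

// Hypothetical usage sketch; FileInputStream/InputStream are already imported by this file.
try (InputStream modelStream = new FileInputStream("frozen_graph.pb")) {
  TensorFlowInferenceInterface inference = new TensorFlowInferenceInterface(modelStream);
  // ... feed(), run(), fetch() as usual ...
} // the stream is closed here by try-with-resources, not by the constructor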
|
||||
|
||||
/*
|
||||
* Construct a TensorFlowInferenceInterface with provided Graph
|
||||
*
|
||||
* @param g The Graph to use to construct this interface.
|
||||
*/
|
||||
public TensorFlowInferenceInterface(Graph g) {
|
||||
prepareNativeRuntime();
|
||||
|
||||
// modelName is not used when wrapping an existing Graph; it is assigned only
// because the final field must be initialized.
|
||||
this.modelName = "";
|
||||
this.g = g;
|
||||
this.sess = new Session(g);
|
||||
this.runner = sess.runner();
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs inference between the previously registered input nodes (via feed*) and the requested
|
||||
* output nodes. Output nodes can then be queried with the fetch* methods.
|
||||
*
|
||||
* @param outputNames A list of output nodes which should be filled by the inference pass.
|
||||
*/
|
||||
public void run(String[] outputNames) {
|
||||
run(outputNames, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs inference between the previously registered input nodes (via feed*) and the requested
|
||||
* output nodes. Output nodes can then be queried with the fetch* methods.
|
||||
*
|
||||
* @param outputNames A list of output nodes which should be filled by the inference pass.
|
||||
*/
|
||||
public void run(String[] outputNames, boolean enableStats) {
|
||||
run(outputNames, enableStats, new String[] {});
|
||||
}
|
||||
|
||||
/** An overloaded version of run() that also allows supplying targetNodeNames. */
|
||||
public void run(String[] outputNames, boolean enableStats, String[] targetNodeNames) {
|
||||
// Release any Tensors from the previous run calls.
|
||||
closeFetches();
|
||||
|
||||
// Add fetches.
|
||||
for (String o : outputNames) {
|
||||
fetchNames.add(o);
|
||||
TensorId tid = TensorId.parse(o);
|
||||
runner.fetch(tid.name, tid.outputIndex);
|
||||
}
|
||||
|
||||
// Add targets.
|
||||
for (String t : targetNodeNames) {
|
||||
runner.addTarget(t);
|
||||
}
|
||||
|
||||
// Run the session.
|
||||
try {
|
||||
if (enableStats) {
|
||||
Session.Run r = runner.setOptions(RunStats.runOptions()).runAndFetchMetadata();
|
||||
fetchTensors = r.outputs;
|
||||
|
||||
if (runStats == null) {
|
||||
runStats = new RunStats();
|
||||
}
|
||||
runStats.add(r.metadata);
|
||||
} else {
|
||||
fetchTensors = runner.run();
|
||||
}
|
||||
} catch (RuntimeException e) {
|
||||
// Log the feed and fetch names to make failures easier to diagnose, then rethrow;
// this interface predates the richer error reporting in the TensorFlow Java API.
|
||||
Log.e(
|
||||
TAG,
|
||||
"Failed to run TensorFlow inference with inputs:["
|
||||
+ TextUtils.join(", ", feedNames)
|
||||
+ "], outputs:["
|
||||
+ TextUtils.join(", ", fetchNames)
|
||||
+ "]");
|
||||
throw e;
|
||||
} finally {
|
||||
// Always release the feeds (to save resources) and reset the runner; this run is over.
|
||||
closeFeeds();
|
||||
runner = sess.runner();
|
||||
}
|
||||
}
|
||||
|
||||
/** Returns a reference to the Graph describing the computation run during inference. */
|
||||
public Graph graph() {
|
||||
return g;
|
||||
}
|
||||
|
||||
public Operation graphOperation(String operationName) {
|
||||
final Operation operation = g.operation(operationName);
|
||||
if (operation == null) {
|
||||
throw new RuntimeException(
|
||||
"Node '" + operationName + "' does not exist in model '" + modelName + "'");
|
||||
}
|
||||
return operation;
|
||||
}
|
||||
|
||||
/** Returns the last stat summary string if logging is enabled. */
|
||||
public String getStatString() {
|
||||
return (runStats == null) ? "" : runStats.summary();
|
||||
}
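Taken together with the run overloads above, stats collection looks roughly like the following sketch; the output node name and the `inference` variable are assumptions, not part of the original file:

// Hypothetical: "output" is an assumed output node name.
inference.run(new String[] {"output"}, /* enableStats= */ true);
Log.i(TAG, inference.getStatString());  // per-node timing summary collected by RunStats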
|
||||
|
||||
/**
|
||||
* Cleans up the state associated with this Object.
|
||||
*
|
||||
* <p>The TensorFlowInferenceInterface object is no longer usable after this method returns.
|
||||
*/
|
||||
public void close() {
|
||||
closeFeeds();
|
||||
closeFetches();
|
||||
sess.close();
|
||||
g.close();
|
||||
if (runStats != null) {
|
||||
runStats.close();
|
||||
}
|
||||
runStats = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() throws Throwable {
|
||||
try {
|
||||
close();
|
||||
} finally {
|
||||
super.finalize();
|
||||
}
|
||||
}
|
||||
|
||||
// Methods for taking a native Tensor and filling it with values from Java arrays.
|
||||
|
||||
/**
|
||||
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, boolean[] src, long... dims) {
|
||||
byte[] b = new byte[src.length];
|
||||
|
||||
for (int i = 0; i < src.length; i++) {
|
||||
b[i] = src[i] ? (byte) 1 : (byte) 0;
|
||||
}
|
||||
|
||||
addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, float[] src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, int[] src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, long[] src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(dims, LongBuffer.wrap(src)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, double[] src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
|
||||
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
|
||||
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, byte[] src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(UInt8.class, dims, ByteBuffer.wrap(src)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Copy a byte sequence into the input Tensor with name {@link inputName} as a string-valued
|
||||
* scalar tensor. In the TensorFlow type system, a "string" is an arbitrary sequence of bytes, not
|
||||
* a Java {@code String} (which is a sequence of characters).
|
||||
*/
|
||||
public void feedString(String inputName, byte[] src) {
|
||||
addFeed(inputName, Tensors.create(src));
|
||||
}
|
||||
|
||||
/**
|
||||
* Copy an array of byte sequences into the input Tensor with name {@link inputName} as a
|
||||
* string-valued one-dimensional tensor (vector). In the TensorFlow type system, a "string" is an
|
||||
* arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters).
|
||||
*/
|
||||
public void feedString(String inputName, byte[][] src) {
|
||||
addFeed(inputName, Tensors.create(src));
|
||||
}
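Concretely, a Java String must be encoded to bytes before being fed. A minimal sketch, assuming the graph contains string placeholders named "input_text" and "input_batch" (hypothetical names); UTF-8 is one common encoding choice:

// Hypothetical node names; the encoding is up to the model.
byte[] utf8 = "hello world".getBytes(java.nio.charset.StandardCharsets.UTF_8);
inference.feedString("input_text", utf8);                        // scalar string tensor
inference.feedString("input_batch", new byte[][] {utf8, utf8});  // 1-D string tensor (vector)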
|
||||
|
||||
// Methods for taking a native Tensor and filling it with src from Java native IO buffers.
|
||||
|
||||
/**
|
||||
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
|
||||
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
|
||||
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
|
||||
* elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, FloatBuffer src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(dims, src));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
|
||||
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
|
||||
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
|
||||
* elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, IntBuffer src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(dims, src));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
|
||||
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
|
||||
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
|
||||
* elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, LongBuffer src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(dims, src));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
|
||||
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
|
||||
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
|
||||
* elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, DoubleBuffer src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(dims, src));
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
|
||||
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
|
||||
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
|
||||
* elements as that of the destination Tensor. If {@link src} has more elements than the
|
||||
* destination has capacity, the copy is truncated.
|
||||
*/
|
||||
public void feed(String inputName, ByteBuffer src, long... dims) {
|
||||
addFeed(inputName, Tensor.create(UInt8.class, dims, src));
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
|
||||
* dst} must have length greater than or equal to that of the source Tensor. This operation will
|
||||
* not affect dst's content past the source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, float[] dst) {
|
||||
fetch(outputName, FloatBuffer.wrap(dst));
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
|
||||
* dst} must have length greater than or equal to that of the source Tensor. This operation will
|
||||
* not affect dst's content past the source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, int[] dst) {
|
||||
fetch(outputName, IntBuffer.wrap(dst));
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
|
||||
* dst} must have length greater than or equal to that of the source Tensor. This operation will
|
||||
* not affect dst's content past the source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, long[] dst) {
|
||||
fetch(outputName, LongBuffer.wrap(dst));
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
|
||||
* dst} must have length greater than or equal to that of the source Tensor. This operation will
|
||||
* not affect dst's content past the source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, double[] dst) {
|
||||
fetch(outputName, DoubleBuffer.wrap(dst));
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
|
||||
* dst} must have length greater than or equal to that of the source Tensor. This operation will
|
||||
* not affect dst's content past the source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, byte[] dst) {
|
||||
fetch(outputName, ByteBuffer.wrap(dst));
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
|
||||
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
|
||||
* or equal to that of the source Tensor. This operation will not affect dst's content past the
|
||||
* source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, FloatBuffer dst) {
|
||||
getTensor(outputName).writeTo(dst);
|
||||
}
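The buffer variants require a direct, native-ordered buffer. A short sketch, assuming a float output named "softmax" with NUM_CLASSES elements (both are illustrative assumptions):

// NUM_CLASSES and the node name are assumptions for illustration.
FloatBuffer probabilities =
    ByteBuffer.allocateDirect(NUM_CLASSES * 4)          // 4 bytes per float
        .order(java.nio.ByteOrder.nativeOrder())        // must be native byte order
        .asFloatBuffer();
inference.fetch("softmax", probabilities);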
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
|
||||
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
|
||||
* or equal to that of the source Tensor. This operation will not affect dst's content past the
|
||||
* source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, IntBuffer dst) {
|
||||
getTensor(outputName).writeTo(dst);
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
|
||||
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
|
||||
* or equal to that of the source Tensor. This operation will not affect dst's content past the
|
||||
* source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, LongBuffer dst) {
|
||||
getTensor(outputName).writeTo(dst);
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
|
||||
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
|
||||
* or equal to that of the source Tensor. This operation will not affect dst's content past the
|
||||
* source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, DoubleBuffer dst) {
|
||||
getTensor(outputName).writeTo(dst);
|
||||
}
|
||||
|
||||
/**
|
||||
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
|
||||
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
|
||||
* or equal to that of the source Tensor. This operation will not affect dst's content past the
|
||||
* source Tensor's size.
|
||||
*/
|
||||
public void fetch(String outputName, ByteBuffer dst) {
|
||||
getTensor(outputName).writeTo(dst);
|
||||
}
|
||||
|
||||
private void prepareNativeRuntime() {
|
||||
Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded");
|
||||
try {
|
||||
// Hack to see if the native libraries have been loaded.
|
||||
new RunStats();
|
||||
Log.i(TAG, "TensorFlow native methods already loaded");
|
||||
} catch (UnsatisfiedLinkError e1) {
|
||||
Log.i(
|
||||
TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference");
|
||||
try {
|
||||
System.loadLibrary("tensorflow_inference");
|
||||
Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
|
||||
} catch (UnsatisfiedLinkError e2) {
|
||||
throw new RuntimeException(
|
||||
"Native TF methods not found; check that the correct native"
|
||||
+ " libraries are present in the APK.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void loadGraph(byte[] graphDef, Graph g) throws IOException {
|
||||
final long startMs = System.currentTimeMillis();
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.beginSection("importGraphDef");
|
||||
}
|
||||
|
||||
try {
|
||||
g.importGraphDef(graphDef);
|
||||
} catch (IllegalArgumentException e) {
|
||||
throw new IOException("Not a valid TensorFlow Graph serialization: " + e.getMessage());
|
||||
}
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.endSection(); // importGraphDef.
|
||||
}
|
||||
|
||||
final long endMs = System.currentTimeMillis();
|
||||
Log.i(
|
||||
TAG,
|
||||
"Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version());
|
||||
}
|
||||
|
||||
private void addFeed(String inputName, Tensor<?> t) {
|
||||
// The string format accepted by TensorFlowInferenceInterface is node_name[:output_index].
|
||||
TensorId tid = TensorId.parse(inputName);
|
||||
runner.feed(tid.name, tid.outputIndex, t);
|
||||
feedNames.add(inputName);
|
||||
feedTensors.add(t);
|
||||
}
|
||||
|
||||
private static class TensorId {
|
||||
String name;
|
||||
int outputIndex;
|
||||
|
||||
// Parse output names into a TensorId.
|
||||
//
|
||||
// E.g., "foo" --> ("foo", 0), while "foo:1" --> ("foo", 1)
|
||||
public static TensorId parse(String name) {
|
||||
TensorId tid = new TensorId();
|
||||
int colonIndex = name.lastIndexOf(':');
|
||||
if (colonIndex < 0) {
|
||||
tid.outputIndex = 0;
|
||||
tid.name = name;
|
||||
return tid;
|
||||
}
|
||||
try {
|
||||
tid.outputIndex = Integer.parseInt(name.substring(colonIndex + 1));
|
||||
tid.name = name.substring(0, colonIndex);
|
||||
} catch (NumberFormatException e) {
|
||||
tid.outputIndex = 0;
|
||||
tid.name = name;
|
||||
}
|
||||
return tid;
|
||||
}
|
||||
}
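Both feed and fetch names accept the node_name[:output_index] form parsed above. A brief illustration with assumed node names and arrays:

// "input" and "input:0" are equivalent; ":1" selects the second output of a node.
inference.feed("input", pixels, 1, 224, 224, 3);
inference.fetch("rnn/while/Exit:1", state);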
|
||||
|
||||
private Tensor<?> getTensor(String outputName) {
|
||||
int i = 0;
|
||||
for (String n : fetchNames) {
|
||||
if (n.equals(outputName)) {
|
||||
return fetchTensors.get(i);
|
||||
}
|
||||
++i;
|
||||
}
|
||||
throw new RuntimeException(
|
||||
"Node '" + outputName + "' was not provided to run(), so it cannot be read");
|
||||
}
|
||||
|
||||
private void closeFeeds() {
|
||||
for (Tensor<?> t : feedTensors) {
|
||||
t.close();
|
||||
}
|
||||
feedTensors.clear();
|
||||
feedNames.clear();
|
||||
}
|
||||
|
||||
private void closeFetches() {
|
||||
for (Tensor<?> t : fetchTensors) {
|
||||
t.close();
|
||||
}
|
||||
fetchTensors.clear();
|
||||
fetchNames.clear();
|
||||
}
|
||||
|
||||
// Immutable state.
|
||||
private final String modelName;
|
||||
private final Graph g;
|
||||
private final Session sess;
|
||||
|
||||
// State reset on every call to run.
|
||||
private Session.Runner runner;
|
||||
private List<String> feedNames = new ArrayList<String>();
|
||||
private List<Tensor<?>> feedTensors = new ArrayList<Tensor<?>>();
|
||||
private List<String> fetchNames = new ArrayList<String>();
|
||||
private List<Tensor<?>> fetchTensors = new ArrayList<Tensor<?>>();
|
||||
|
||||
// Mutable state.
|
||||
private RunStats runStats;
|
||||
}
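For reference, a minimal end-to-end sketch of the class deleted here, under assumed names: the asset path, node names, and shapes are illustrative, and getAssets() assumes an Android Context.

// Hypothetical end-to-end usage; all names and sizes are assumptions.
TensorFlowInferenceInterface inference =
    new TensorFlowInferenceInterface(getAssets(), "file:///android_asset/frozen_graph.pb");

float[] pixels = new float[1 * 224 * 224 * 3];    // preprocessed input image
float[] logits = new float[1001];                 // model-dependent output size

inference.feed("input", pixels, 1, 224, 224, 3);  // copy input into a native Tensor
inference.run(new String[] {"output"});           // execute the requested outputs
inference.fetch("output", logits);                // copy the result back to Java

inference.close();                                // release Graph/Session resources when done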
|
@ -1,83 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/android/jni/run_stats_jni.h"
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/util/stat_summarizer.h"
|
||||
|
||||
using tensorflow::RunMetadata;
|
||||
using tensorflow::StatSummarizer;
|
||||
|
||||
namespace {
|
||||
StatSummarizer* requireHandle(JNIEnv* env, jlong handle) {
|
||||
if (handle == 0) {
|
||||
env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
|
||||
"close() has been called on the RunStats object");
|
||||
return nullptr;
|
||||
}
|
||||
return reinterpret_cast<StatSummarizer*>(handle);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
#define RUN_STATS_METHOD(name) \
|
||||
JNICALL Java_org_tensorflow_contrib_android_RunStats_##name
|
||||
|
||||
JNIEXPORT jlong RUN_STATS_METHOD(allocate)(JNIEnv* env, jclass clazz) {
|
||||
static_assert(sizeof(jlong) >= sizeof(StatSummarizer*),
|
||||
"Cannot package C++ object pointers as a Java long");
|
||||
tensorflow::StatSummarizerOptions opts;
|
||||
return reinterpret_cast<jlong>(new StatSummarizer(opts));
|
||||
}
|
||||
|
||||
JNIEXPORT void RUN_STATS_METHOD(delete)(JNIEnv* env, jclass clazz,
|
||||
jlong handle) {
|
||||
if (handle == 0) return;
|
||||
delete reinterpret_cast<StatSummarizer*>(handle);
|
||||
}
|
||||
|
||||
JNIEXPORT void RUN_STATS_METHOD(add)(JNIEnv* env, jclass clazz, jlong handle,
|
||||
jbyteArray run_metadata) {
|
||||
StatSummarizer* s = requireHandle(env, handle);
|
||||
if (s == nullptr) return;
|
||||
jbyte* data = env->GetByteArrayElements(run_metadata, nullptr);
|
||||
int size = static_cast<int>(env->GetArrayLength(run_metadata));
|
||||
tensorflow::RunMetadata proto;
|
||||
if (!proto.ParseFromArray(data, size)) {
|
||||
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"),
|
||||
"runMetadata does not seem to be a serialized RunMetadata "
|
||||
"protocol message");
|
||||
} else if (proto.has_step_stats()) {
|
||||
s->ProcessStepStats(proto.step_stats());
|
||||
}
|
||||
env->ReleaseByteArrayElements(run_metadata, data, JNI_ABORT);
|
||||
}
|
||||
|
||||
JNIEXPORT jstring RUN_STATS_METHOD(summary)(JNIEnv* env, jclass clazz,
|
||||
jlong handle) {
|
||||
StatSummarizer* s = requireHandle(env, handle);
|
||||
if (s == nullptr) return nullptr;
|
||||
std::stringstream ret;
|
||||
ret << s->GetStatsByMetric("Top 10 CPU", tensorflow::StatsCalculator::BY_TIME,
|
||||
10)
|
||||
<< s->GetStatsByNodeType() << s->ShortSummary();
|
||||
return env->NewStringUTF(ret.str().c_str());
|
||||
}
|
||||
|
||||
#undef RUN_STATS_METHOD
|
@ -1,40 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_
|
||||
#define ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
#define RUN_STATS_METHOD(name) \
|
||||
Java_org_tensorflow_contrib_android_RunStats_##name
|
||||
|
||||
JNIEXPORT JNICALL jlong RUN_STATS_METHOD(allocate)(JNIEnv*, jclass);
|
||||
JNIEXPORT JNICALL void RUN_STATS_METHOD(delete)(JNIEnv*, jclass, jlong);
|
||||
JNIEXPORT JNICALL void RUN_STATS_METHOD(add)(JNIEnv*, jclass, jlong,
|
||||
jbyteArray);
|
||||
JNIEXPORT JNICALL jstring RUN_STATS_METHOD(summary)(JNIEnv*, jclass, jlong);
|
||||
|
||||
#undef RUN_STATS_METHOD
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_
|
@ -1,11 +0,0 @@
|
||||
VERS_1.0 {
|
||||
# Export JNI symbols.
|
||||
global:
|
||||
Java_*;
|
||||
JNI_OnLoad;
|
||||
JNI_OnUnload;
|
||||
|
||||
# Hide everything else.
|
||||
local:
|
||||
*;
|
||||
};
|
@ -1,29 +0,0 @@
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "autograph",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
# This module is kept for backward compatibility only. To depend on AutoGraph,
|
||||
# use //third_party/tensorflow/python/autograph instead.
|
||||
deps = [
|
||||
"//tensorflow/python/autograph",
|
||||
],
|
||||
)
|
@ -1,9 +0,0 @@
|
||||
# AutoGraph
|
||||
|
||||
**NOTE: As tensorflow.contrib is being
|
||||
[deprecated](https://github.com/tensorflow/community/pull/18), AutoGraph is
|
||||
moving into TensorFlow core.
|
||||
|
||||
The new code location is `tensorflow/python/autograph`. Please refer to the
|
||||
README.md file in that directory.
|
||||
**
|
@ -1,24 +0,0 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""This is the legacy module for AutoGraph, kept for backward compatibility.
|
||||
|
||||
New users should instead use `tensorflow.python.autograph`.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph import * # pylint:disable=wildcard-import
|
@ -1,39 +0,0 @@
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
load("//tensorflow/tools/test:performance.bzl", "tf_py_logged_benchmark")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "benchmark_base",
|
||||
srcs = [
|
||||
"benchmark_base.py",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "cartpole_benchmark",
|
||||
size = "enormous",
|
||||
srcs = ["cartpole_benchmark.py"],
|
||||
python_version = "PY2",
|
||||
tags = [
|
||||
"local",
|
||||
"manual",
|
||||
"no_oss",
|
||||
"notap",
|
||||
"nozapfhahn",
|
||||
],
|
||||
deps = [
|
||||
":benchmark_base",
|
||||
# Note: required gym dependency may need to be added here.
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_logged_benchmark(
|
||||
name = "cartpole_logged_benchmark",
|
||||
target = "//tensorflow/contrib/autograph/examples/benchmarks:cartpole_benchmark",
|
||||
)
|
@ -1,62 +0,0 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Common benchmarking code.
|
||||
|
||||
See https://www.tensorflow.org/community/benchmarks for usage.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class ReportingBenchmark(tf.test.Benchmark):
|
||||
"""Base class for a benchmark that reports general performance metrics.
|
||||
|
||||
Subclasses only need to call time_execution, which records wall times and
reports them via report_benchmark.
|
||||
"""
|
||||
|
||||
def time_execution(self, name, target, iters, warm_up_iters=5):
|
||||
for _ in range(warm_up_iters):
|
||||
target()
|
||||
|
||||
all_times = []
|
||||
for _ in range(iters):
|
||||
iter_time = time.time()
|
||||
target()
|
||||
all_times.append(time.time() - iter_time)
|
||||
|
||||
avg_time = np.average(all_times)
|
||||
|
||||
extras = {}
|
||||
extras['all_times'] = all_times
|
||||
|
||||
if isinstance(name, tuple):
|
||||
extras['name'] = name
|
||||
name = '_'.join(str(piece) for piece in name)
|
||||
|
||||
self.report_benchmark(
|
||||
iters=iters, wall_time=avg_time, name=name, extras=extras)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
@ -1,492 +0,0 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A basic RL cartpole benchmark.
|
||||
|
||||
The RL model uses the OpenAI Gym environment to train a simple network using
|
||||
the policy gradients method. The training scales the gradients for each step
|
||||
by the episode's cumulative discounted reward and averages these gradients over
|
||||
a fixed number of games before applying the optimization step.
|
||||
|
||||
For benchmarking purposes, we replace the OpenAI Gym environment with a fake one
|
||||
that returns random actions and rewards and never ends the episode. This way
|
||||
the benchmarks compare the same amount of computation at each step.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib import eager
|
||||
from tensorflow.contrib.autograph.examples.benchmarks import benchmark_base
|
||||
from tensorflow.python import autograph as ag
|
||||
from tensorflow.python.eager import context
|
||||
|
||||
#
|
||||
# AutoGraph implementation
|
||||
#
|
||||
|
||||
|
||||
@ag.convert()
|
||||
def graph_append_discounted_rewards(destination, rewards, discount_rate):
|
||||
"""Discounts episode rewards and appends them to destination."""
|
||||
ag.set_element_type(rewards, tf.float32)
|
||||
|
||||
cdr = 0.0
|
||||
reverse_discounted = []
|
||||
ag.set_element_type(reverse_discounted, tf.float32)
|
||||
|
||||
for i in range(len(rewards) - 1, -1, -1):
|
||||
cdr = cdr * discount_rate + rewards[i]
|
||||
cdr.set_shape(())
|
||||
reverse_discounted.append(cdr)
|
||||
|
||||
retval = destination
|
||||
# Note: AutoGraph doesn't yet support reversed() so we use a loop instead.
|
||||
for i in range(len(reverse_discounted) - 1, -1, -1):
|
||||
retval.append(reverse_discounted[i])
|
||||
|
||||
return retval
|
||||
|
||||
|
||||
class GraphPolicyNetwork(tf.keras.Model):
|
||||
"""Policy network for the cart-pole reinforcement learning problem.
|
||||
|
||||
The forward path of the network takes an observation from the cart-pole
|
||||
environment (length-4 vector) and outputs an action.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size):
|
||||
super(GraphPolicyNetwork, self).__init__()
|
||||
self._hidden_layer = tf.keras.layers.Dense(
|
||||
hidden_size, activation=tf.nn.elu)
|
||||
self._output_layer = tf.keras.layers.Dense(1)
|
||||
|
||||
def call(self, inputs):
|
||||
"""Calculates logits and action.
|
||||
|
||||
Args:
|
||||
inputs: Observations from a step in the cart-pole environment, of shape
|
||||
`(batch_size, input_size)`
|
||||
|
||||
Returns:
|
||||
logits: the logits output by the output layer. This can be viewed as the
|
||||
likelihood values of choosing the left (0) action. Shape:
|
||||
`(batch_size, 1)`.
|
||||
actions: randomly selected actions ({0, 1}) based on the logits. Shape:
|
||||
`(batch_size, 1)`.
|
||||
"""
|
||||
hidden = self._hidden_layer(inputs)
|
||||
logits = self._output_layer(hidden)
|
||||
|
||||
left_prob = tf.nn.sigmoid(logits)
|
||||
action_probs = tf.concat([left_prob, 1.0 - left_prob], 1)
|
||||
|
||||
actions = tf.multinomial(tf.log(action_probs), 1)
|
||||
return logits, actions
|
||||
|
||||
# TODO(mdan): Move this method out of the class.
|
||||
@ag.convert()
|
||||
def train(self, cart_pole_env, optimizer, discount_rate, num_games,
|
||||
max_steps_per_game):
|
||||
var_list = tf.trainable_variables()
|
||||
grad_list = [
|
||||
tf.TensorArray(tf.float32, 0, dynamic_size=True) for _ in var_list
|
||||
]
|
||||
|
||||
step_counts = []
|
||||
discounted_rewards = []
|
||||
ag.set_element_type(discounted_rewards, tf.float32)
|
||||
ag.set_element_type(step_counts, tf.int32)
|
||||
|
||||
# Note: we use a shared object, cart_pole_env here. Because calls to the
|
||||
# object's method are made through py_func, TensorFlow cannot detect its
|
||||
# data dependencies. Hence we must manually synchronize access to it
|
||||
# and ensure the control dependencies are set in such a way that
|
||||
# calls to reset(), take_one_step, etc. are made in the correct order.
|
||||
sync_counter = tf.constant(0)
|
||||
|
||||
for _ in tf.range(num_games):
|
||||
with tf.control_dependencies([sync_counter]):
|
||||
obs = cart_pole_env.reset()
|
||||
with tf.control_dependencies([obs]):
|
||||
sync_counter += 1
|
||||
|
||||
game_rewards = []
|
||||
ag.set_element_type(game_rewards, tf.float32)
|
||||
|
||||
for step in tf.range(max_steps_per_game):
|
||||
logits, actions = self(obs) # pylint:disable=not-callable
|
||||
logits = tf.reshape(logits, ())
|
||||
actions = tf.reshape(actions, ())
|
||||
|
||||
labels = 1.0 - tf.cast(actions, tf.float32)
|
||||
loss = tf.nn.sigmoid_cross_entropy_with_logits(
|
||||
labels=labels, logits=logits)
|
||||
grads = tf.gradients(loss, var_list)
|
||||
|
||||
for i in range(len(grads)):
|
||||
grad_list[i].append(grads[i])
|
||||
|
||||
with tf.control_dependencies([sync_counter]):
|
||||
obs, reward, done = cart_pole_env.step(actions)
|
||||
with tf.control_dependencies([obs]):
|
||||
sync_counter += 1
|
||||
obs = tf.reshape(obs, (1, 4))
|
||||
|
||||
game_rewards.append(reward)
|
||||
if reward < 0.1 or done:
|
||||
step_counts.append(step + 1)
|
||||
break
|
||||
|
||||
discounted_rewards = graph_append_discounted_rewards(
|
||||
discounted_rewards, game_rewards, discount_rate)
|
||||
|
||||
discounted_rewards = ag.stack(discounted_rewards)
|
||||
discounted_rewards.set_shape((None,))
|
||||
mean, variance = tf.nn.moments(discounted_rewards, [0])
|
||||
normalized_rewards = (discounted_rewards - mean) / tf.sqrt(variance)
|
||||
|
||||
for i in range(len(grad_list)):
|
||||
g = ag.stack(grad_list[i])
|
||||
|
||||
# This block just adjusts the shapes to match for multiplication.
|
||||
r = normalized_rewards
|
||||
if r.shape.ndims < g.shape.ndims:
|
||||
r = tf.expand_dims(r, -1)
|
||||
if r.shape.ndims < g.shape.ndims:
|
||||
r = tf.expand_dims(r, -1)
|
||||
|
||||
grad_list[i] = tf.reduce_mean(g * r, axis=0)
|
||||
|
||||
optimizer.apply_gradients(
|
||||
zip(grad_list, var_list), global_step=tf.train.get_global_step())
|
||||
|
||||
return ag.stack(step_counts)
|
||||
|
||||
|
||||
@ag.convert()
|
||||
def graph_train_model(policy_network, cart_pole_env, optimizer, iterations):
|
||||
"""Trains the policy network for a given number of iterations."""
|
||||
i = tf.constant(0)
|
||||
mean_steps_per_iteration = []
|
||||
ag.set_element_type(mean_steps_per_iteration, tf.int32)
|
||||
|
||||
while i < iterations:
|
||||
steps_per_game = policy_network.train(
|
||||
cart_pole_env,
|
||||
optimizer,
|
||||
discount_rate=0.95,
|
||||
num_games=20,
|
||||
max_steps_per_game=200)
|
||||
mean_steps_per_iteration.append(tf.reduce_mean(steps_per_game))
|
||||
i += 1
|
||||
|
||||
return ag.stack(mean_steps_per_iteration)
|
||||
|
||||
|
||||
class GraphGymCartpoleEnv(object):
|
||||
"""An env backed by OpenAI Gym's CartPole environment.
|
||||
|
||||
Used to confirm a functional model only.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
cart_pole_env = gym.make('CartPole-v1')
|
||||
cart_pole_env.seed(0)
|
||||
cart_pole_env.reset()
|
||||
self.env = cart_pole_env
|
||||
|
||||
def reset(self):
|
||||
obs = ag.utils.wrap_py_func(self.env.reset, tf.float64, ())
|
||||
obs = tf.reshape(obs, (1, 4))
|
||||
obs = tf.cast(obs, tf.float32)
|
||||
return obs
|
||||
|
||||
def step(self, actions):
|
||||
|
||||
def take_one_step(actions):
|
||||
obs, reward, done, _ = self.env.step(actions)
|
||||
obs = obs.astype(np.float32)
|
||||
reward = np.float32(reward)
|
||||
return obs, reward, done
|
||||
|
||||
return ag.utils.wrap_py_func(take_one_step,
|
||||
(tf.float32, tf.float32, tf.bool), (actions,))
|
||||
|
||||
|
||||
class GraphRandomCartpoleEnv(object):
|
||||
"""An environment that returns random actions and never finishes.
|
||||
|
||||
Used during benchmarking, it will cause training to run a constant number of
|
||||
steps.
|
||||
"""
|
||||
|
||||
def reset(self):
|
||||
return tf.random.normal((1, 4))
|
||||
|
||||
def step(self, actions):
|
||||
with tf.control_dependencies([actions]):
|
||||
random_obs = tf.random.normal((1, 4))
|
||||
fixed_reward = tf.constant(0.001)
|
||||
done = tf.constant(False)
|
||||
return random_obs, fixed_reward, done
|
||||
|
||||
|
||||
#
|
||||
# Eager implementation
|
||||
#
|
||||
|
||||
|
||||
def eager_append_discounted_rewards(discounted_rewards, rewards, discount_rate):
|
||||
cdr = 0.0
|
||||
reverse_discounted = []
|
||||
|
||||
for i in range(len(rewards) - 1, -1, -1):
|
||||
cdr = cdr * discount_rate + rewards[i]
|
||||
reverse_discounted.append(cdr)
|
||||
|
||||
discounted_rewards.extend(reversed(reverse_discounted))
|
||||
return discounted_rewards
|
||||
|
||||
|
||||
class EagerPolicyNetwork(tf.keras.Model):
|
||||
"""Policy network for the cart-pole reinforcement learning problem.
|
||||
|
||||
The forward path of the network takes an observation from the cart-pole
|
||||
environment (length-4 vector) and outputs an action.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size):
|
||||
super(EagerPolicyNetwork, self).__init__()
|
||||
self._hidden_layer = tf.keras.layers.Dense(
|
||||
hidden_size, activation=tf.nn.elu)
|
||||
self._output_layer = tf.keras.layers.Dense(1)
|
||||
|
||||
def call(self, inputs):
|
||||
"""Calculates logits and action.
|
||||
|
||||
Args:
|
||||
inputs: Observations from a step in the cart-pole environment, of shape
|
||||
`(batch_size, input_size)`
|
||||
|
||||
Returns:
|
||||
logits: the logits output by the output layer. This can be viewed as the
|
||||
likelihood values of choosing the left (0) action. Shape:
|
||||
`(batch_size, 1)`.
|
||||
actions: randomly selected actions ({0, 1}) based on the logits. Shape:
|
||||
`(batch_size, 1)`.
|
||||
"""
|
||||
hidden = self._hidden_layer(inputs)
|
||||
logits = self._output_layer(hidden)
|
||||
|
||||
left_prob = tf.nn.sigmoid(logits)
|
||||
action_probs = tf.concat([left_prob, 1.0 - left_prob], 1)
|
||||
|
||||
self._grad_fn = eager.implicit_gradients(
|
||||
self._get_cross_entropy_and_save_actions)
|
||||
|
||||
actions = tf.multinomial(tf.log(action_probs), 1)
|
||||
return logits, actions
|
||||
|
||||
def _get_cross_entropy_and_save_actions(self, inputs):
|
||||
logits, actions = self(inputs) # pylint:disable=not-callable
|
||||
self._current_actions = actions
|
||||
labels = 1.0 - tf.cast(actions, tf.float32)
|
||||
return tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
|
||||
|
||||
def train(self, cart_pole_env, optimizer, discount_rate, num_games,
|
||||
max_steps_per_game):
|
||||
grad_list = None
|
||||
|
||||
step_counts = []
|
||||
discounted_rewards = []
|
||||
|
||||
for _ in range(num_games):
|
||||
obs = cart_pole_env.reset()
|
||||
|
||||
game_rewards = []
|
||||
|
||||
for step in range(max_steps_per_game):
|
||||
grads_and_vars = self._grad_fn(tf.constant([obs], dtype=tf.float32))
|
||||
grads, var_list = zip(*grads_and_vars)
|
||||
actions = self._current_actions.numpy()[0][0]
|
||||
|
||||
if grad_list is None:
|
||||
grad_list = [[g] for g in grads]
|
||||
else:
|
||||
for i in range(len(grads)):
|
||||
grad_list[i].append(grads[i])
|
||||
|
||||
obs, reward, done = cart_pole_env.step(actions)
|
||||
|
||||
game_rewards.append(reward)
|
||||
if reward < 0.1 or done:
|
||||
step_counts.append(step + 1)
|
||||
break
|
||||
|
||||
discounted_rewards = eager_append_discounted_rewards(
|
||||
discounted_rewards, game_rewards, discount_rate)
|
||||
|
||||
discounted_rewards = tf.stack(discounted_rewards)
|
||||
mean, variance = tf.nn.moments(discounted_rewards, [0])
|
||||
normalized_rewards = (discounted_rewards - mean) / tf.sqrt(variance)
|
||||
|
||||
for i in range(len(grad_list)):
|
||||
g = tf.stack(grad_list[i])
|
||||
|
||||
r = normalized_rewards
|
||||
while r.shape.ndims < g.shape.ndims:
|
||||
r = tf.expand_dims(r, -1)
|
||||
|
||||
grad_list[i] = tf.reduce_mean(g * r, axis=0)
|
||||
|
||||
optimizer.apply_gradients(
|
||||
zip(grad_list, var_list), global_step=tf.train.get_global_step())
|
||||
|
||||
return tf.stack(step_counts)
|
||||
|
||||
|
||||
def eager_train_model(policy_network, cart_pole_env, optimizer, iterations):
|
||||
"""Trains the policy network for a given number of iterations."""
|
||||
mean_steps_per_iteration = []
|
||||
|
||||
for _ in range(iterations):
|
||||
steps_per_game = policy_network.train(
|
||||
cart_pole_env,
|
||||
optimizer,
|
||||
discount_rate=0.95,
|
||||
num_games=20,
|
||||
max_steps_per_game=200)
|
||||
mean_steps_per_iteration.append(tf.reduce_mean(steps_per_game))
|
||||
|
||||
return mean_steps_per_iteration
|
||||
|
||||
|
||||
class EagerGymCartpoleEnv(object):
|
||||
"""An env backed by OpenAI Gym's CartPole environment.
|
||||
|
||||
Used to confirm a functional model only.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
cart_pole_env = gym.make('CartPole-v1')
|
||||
cart_pole_env.seed(0)
|
||||
cart_pole_env.reset()
|
||||
self.env = cart_pole_env
|
||||
|
||||
def reset(self):
|
||||
return self.env.reset()
|
||||
|
||||
def step(self, actions):
|
||||
obs, reward, done, _ = self.env.step(actions)
|
||||
return obs, reward, done
|
||||
|
||||
|
||||
class EagerRandomCartpoleEnv(object):
|
||||
"""An environment that returns random actions and never finishes.
|
||||
|
||||
Used during benchmarking, it will cause training to run a constant number of
|
||||
steps.
|
||||
"""
|
||||
|
||||
def reset(self):
|
||||
return np.random.normal(size=(4,))
|
||||
|
||||
def step(self, actions):
|
||||
with tf.control_dependencies([actions]):
|
||||
random_obs = np.random.normal(size=(4,))
|
||||
fixed_reward = 0.001
|
||||
done = False
|
||||
return random_obs, fixed_reward, done
|
||||
|
||||
|
||||
def graph_demo_training():
|
||||
"""Not used in the benchmark. Used to confirm a functional model."""
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
|
||||
network = GraphPolicyNetwork(hidden_size=5)
|
||||
network.build((1, 4))
|
||||
env = GraphGymCartpoleEnv()
|
||||
opt = tf.train.AdamOptimizer(0.05)
|
||||
|
||||
train_ops = graph_train_model(network, env, opt, iterations=5)
|
||||
|
||||
with tf.Session() as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
sess.run(tf.local_variables_initializer())
|
||||
steps_per_iteration = sess.run(train_ops)
|
||||
for i, steps in enumerate(steps_per_iteration):
|
||||
print('Step {} iterations: {}'.format(i, steps))
|
||||
|
||||
|
||||
def eager_demo_training():
|
||||
with context.eager_mode():
|
||||
network = EagerPolicyNetwork(hidden_size=5)
|
||||
network.build((1, 4))
|
||||
env = EagerGymCartpoleEnv()
|
||||
opt = tf.train.AdamOptimizer(0.05)
|
||||
|
||||
steps_per_iteration = eager_train_model(network, env, opt, iterations=5)
|
||||
for i, steps in enumerate(steps_per_iteration):
|
||||
print('Step {} iterations: {}'.format(i, steps))
|
||||
|
||||
|
||||
class RLCartPoleBenchmark(benchmark_base.ReportingBenchmark):
|
||||
"""Actual benchmark.
|
||||
|
||||
Trains the RL agent a fixed number of times, on random environments that
|
||||
result in constant number of steps.
|
||||
"""
|
||||
|
||||
def benchmark_cartpole(self):
|
||||
|
||||
def train_session(sess, ops):
|
||||
return lambda: sess.run(ops)
|
||||
|
||||
def train_eager(network, env, opt):
|
||||
return lambda: eager_train_model(network, env, opt, iterations=10)
|
||||
|
||||
for model_size in (10, 100, 1000):
|
||||
with tf.Graph().as_default():
|
||||
network = GraphPolicyNetwork(hidden_size=model_size)
|
||||
network.build((1, 4))
|
||||
env = GraphRandomCartpoleEnv()
|
||||
opt = tf.train.AdamOptimizer(0.05)
|
||||
train_ops = graph_train_model(network, env, opt, iterations=10)
|
||||
|
||||
with tf.Session() as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
sess.run(tf.local_variables_initializer())
|
||||
|
||||
self.time_execution(('cartpole', 'autograph', model_size),
|
||||
train_session(sess, train_ops), 20)
|
||||
|
||||
with context.eager_mode():
|
||||
network = EagerPolicyNetwork(hidden_size=model_size)
|
||||
network.build((1, 4))
|
||||
env = EagerRandomCartpoleEnv()
|
||||
opt = tf.train.AdamOptimizer(0.05)
|
||||
|
||||
self.time_execution(('cartpole', 'eager', model_size),
|
||||
train_eager(network, env, opt), 20)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
File diff suppressed because one or more lines are too long
@ -1,652 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "etTmZVFN8fYO"
|
||||
},
|
||||
"source": [
|
||||
"This notebook runs a basic speed test for a short training loop of a neural network training on the MNIST dataset."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "eqOvRhOz8SWs"
|
||||
},
|
||||
"source": [
|
||||
"### Imports"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "nHY0tntRizGb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -U -q tf-nightly"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "Pa2qpEmoVOGe"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gzip\n",
|
||||
"import os\n",
|
||||
"import shutil\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import six\n",
|
||||
"from six.moves import urllib\n",
|
||||
"import tensorflow as tf\n",
|
||||
"\n",
|
||||
"from tensorflow.contrib import autograph as ag\n",
|
||||
"from tensorflow.contrib.eager.python import tfe\n",
|
||||
"from tensorflow.python.eager import context\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "PZWxEJFM9A7b"
|
||||
},
|
||||
"source": [
|
||||
"### Testing boilerplate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "kfZk9EFZ5TeQ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Test-only parameters. Test checks successful completion not correctness. \n",
|
||||
"burn_ins = 1\n",
|
||||
"trials = 1\n",
|
||||
"max_steps = 2\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "k0GKbZBJ9Gt9"
|
||||
},
|
||||
"source": [
|
||||
"### Speed test configuration"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "gWXV8WHn43iZ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@test {\"skip\": true} \n",
|
||||
"burn_ins = 3\n",
|
||||
"trials = 10\n",
|
||||
"max_steps = 500\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "kZV_3pGy8033"
|
||||
},
|
||||
"source": [
|
||||
"### Data source setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "YfnHJbBOBKae"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def download(directory, filename):\n",
|
||||
" filepath = os.path.join(directory, filename)\n",
|
||||
" if tf.gfile.Exists(filepath):\n",
|
||||
" return filepath\n",
|
||||
" if not tf.gfile.Exists(directory):\n",
|
||||
" tf.gfile.MakeDirs(directory)\n",
|
||||
" url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n",
|
||||
" zipped_filepath = filepath + '.gz'\n",
|
||||
" print('Downloading %s to %s' % (url, zipped_filepath))\n",
|
||||
" urllib.request.urlretrieve(url, zipped_filepath)\n",
|
||||
" with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:\n",
|
||||
" shutil.copyfileobj(f_in, f_out)\n",
|
||||
" os.remove(zipped_filepath)\n",
|
||||
" return filepath\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def dataset(directory, images_file, labels_file):\n",
|
||||
" images_file = download(directory, images_file)\n",
|
||||
" labels_file = download(directory, labels_file)\n",
|
||||
"\n",
|
||||
" def decode_image(image):\n",
|
||||
" # Normalize from [0, 255] to [0.0, 1.0]\n",
|
||||
" image = tf.decode_raw(image, tf.uint8)\n",
|
||||
" image = tf.cast(image, tf.float32)\n",
|
||||
" image = tf.reshape(image, [784])\n",
|
||||
" return image / 255.0\n",
|
||||
"\n",
|
||||
" def decode_label(label):\n",
|
||||
" label = tf.decode_raw(label, tf.uint8)\n",
|
||||
" label = tf.reshape(label, [])\n",
|
||||
" return tf.to_int32(label)\n",
|
||||
"\n",
|
||||
" images = tf.data.FixedLengthRecordDataset(\n",
|
||||
" images_file, 28 * 28, header_bytes=16).map(decode_image)\n",
|
||||
" labels = tf.data.FixedLengthRecordDataset(\n",
|
||||
" labels_file, 1, header_bytes=8).map(decode_label)\n",
|
||||
" return tf.data.Dataset.zip((images, labels))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def mnist_train(directory):\n",
|
||||
" return dataset(directory, 'train-images-idx3-ubyte',\n",
|
||||
" 'train-labels-idx1-ubyte')\n",
|
||||
"\n",
|
||||
"def mnist_test(directory):\n",
|
||||
" return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')\n",
|
||||
"\n",
|
||||
"def setup_mnist_data(is_training, hp, batch_size):\n",
|
||||
" if is_training:\n",
|
||||
" ds = mnist_train('/tmp/autograph_mnist_data')\n",
|
||||
" ds = ds.cache()\n",
|
||||
" ds = ds.shuffle(batch_size * 10)\n",
|
||||
" else:\n",
|
||||
" ds = mnist_test('/tmp/autograph_mnist_data')\n",
|
||||
" ds = ds.cache()\n",
|
||||
" ds = ds.repeat()\n",
|
||||
" ds = ds.batch(batch_size)\n",
|
||||
" return ds\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "qzkZyZcS9THu"
|
||||
},
|
||||
"source": [
|
||||
"### Keras model definition"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "x_MU13boiok2"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def mlp_model(input_shape):\n",
|
||||
" model = tf.keras.Sequential((\n",
|
||||
" tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n",
|
||||
" tf.keras.layers.Dense(100, activation='relu'),\n",
|
||||
" tf.keras.layers.Dense(10, activation='softmax')))\n",
|
||||
" model.build()\n",
|
||||
" return model\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "DXt4GoTxtvn2"
|
||||
},
|
||||
"source": [
|
||||
"# AutoGraph"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "W51sfbONiz_5"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def predict(m, x, y):\n",
|
||||
" y_p = m(x)\n",
|
||||
" losses = tf.keras.losses.categorical_crossentropy(y, y_p)\n",
|
||||
" l = tf.reduce_mean(losses)\n",
|
||||
" accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n",
|
||||
" accuracy = tf.reduce_mean(accuracies)\n",
|
||||
" return l, accuracy\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "CsAD0ajbi9iZ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def fit(m, x, y, opt):\n",
|
||||
" l, accuracy = predict(m, x, y)\n",
|
||||
" opt.minimize(l)\n",
|
||||
" return l, accuracy\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "RVw57HdTjPzi"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_next_batch(ds):\n",
|
||||
" itr = ds.make_one_shot_iterator()\n",
|
||||
" image, label = itr.get_next()\n",
|
||||
" x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n",
|
||||
" y = tf.one_hot(tf.squeeze(label), 10)\n",
|
||||
" return x, y\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "UUI0566FjZPx"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train(train_ds, test_ds, hp):\n",
|
||||
" m = mlp_model((28 * 28,))\n",
|
||||
" opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n",
|
||||
"\n",
|
||||
" train_losses = []\n",
|
||||
" test_losses = []\n",
|
||||
" train_accuracies = []\n",
|
||||
" test_accuracies = []\n",
|
||||
" ag.set_element_type(train_losses, tf.float32)\n",
|
||||
" ag.set_element_type(test_losses, tf.float32)\n",
|
||||
" ag.set_element_type(train_accuracies, tf.float32)\n",
|
||||
" ag.set_element_type(test_accuracies, tf.float32)\n",
|
||||
"\n",
|
||||
" i = tf.constant(0)\n",
|
||||
" while i \u003c hp.max_steps:\n",
|
||||
" train_x, train_y = get_next_batch(train_ds)\n",
|
||||
" test_x, test_y = get_next_batch(test_ds)\n",
|
||||
" step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n",
|
||||
" step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n",
|
||||
"\n",
|
||||
" train_losses.append(step_train_loss)\n",
|
||||
" test_losses.append(step_test_loss)\n",
|
||||
" train_accuracies.append(step_train_accuracy)\n",
|
||||
" test_accuracies.append(step_test_accuracy)\n",
|
||||
"\n",
|
||||
" i += 1\n",
|
||||
" return (ag.stack(train_losses), ag.stack(test_losses),\n",
|
||||
" ag.stack(train_accuracies), ag.stack(test_accuracies))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
},
|
||||
"height": 215
|
||||
},
|
||||
"colab_type": "code",
|
||||
"executionInfo": {
|
||||
"elapsed": 12156,
|
||||
"status": "ok",
|
||||
"timestamp": 1531752050611,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"photoUrl": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": 240
|
||||
},
|
||||
"id": "K1m8TwOKjdNd",
|
||||
"outputId": "bd5746f2-bf91-44aa-9eff-38eb11ced33f"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"('Duration:', 0.6226680278778076)\n",
|
||||
"('Duration:', 0.6082069873809814)\n",
|
||||
"('Duration:', 0.6223258972167969)\n",
|
||||
"('Duration:', 0.6176440715789795)\n",
|
||||
"('Duration:', 0.6309840679168701)\n",
|
||||
"('Duration:', 0.6180410385131836)\n",
|
||||
"('Duration:', 0.6219630241394043)\n",
|
||||
"('Duration:', 0.6183009147644043)\n",
|
||||
"('Duration:', 0.6176400184631348)\n",
|
||||
"('Duration:', 0.6476900577545166)\n",
|
||||
"('Mean duration:', 0.62254641056060789, '+/-', 0.0099792188690656976)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"#@test {\"timeout\": 90}\n",
|
||||
"with tf.Graph().as_default():\n",
|
||||
" hp = tf.contrib.training.HParams(\n",
|
||||
" learning_rate=0.05,\n",
|
||||
" max_steps=max_steps,\n",
|
||||
" )\n",
|
||||
" train_ds = setup_mnist_data(True, hp, 500)\n",
|
||||
" test_ds = setup_mnist_data(False, hp, 100)\n",
|
||||
" tf_train = ag.to_graph(train)\n",
|
||||
" losses = tf_train(train_ds, test_ds, hp)\n",
|
||||
"\n",
|
||||
" with tf.Session() as sess:\n",
|
||||
" durations = []\n",
|
||||
" for t in range(burn_ins + trials):\n",
|
||||
" sess.run(tf.global_variables_initializer())\n",
|
||||
"\n",
|
||||
" start = time.time()\n",
|
||||
" (train_losses, test_losses, train_accuracies,\n",
|
||||
" test_accuracies) = sess.run(losses)\n",
|
||||
"\n",
|
||||
" if t \u003c burn_ins:\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" duration = time.time() - start\n",
|
||||
" durations.append(duration)\n",
|
||||
" print('Duration:', duration)\n",
|
||||
"\n",
|
||||
" print('Mean duration:', np.mean(durations), '+/-', np.std(durations))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "A06kdgtZtlce"
|
||||
},
|
||||
"source": [
|
||||
"# Eager"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "hBKOKGrWty4e"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def predict(m, x, y):\n",
|
||||
" y_p = m(x)\n",
|
||||
" losses = tf.keras.losses.categorical_crossentropy(tf.cast(y, tf.float32), y_p)\n",
|
||||
" l = tf.reduce_mean(losses)\n",
|
||||
" accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n",
|
||||
" accuracy = tf.reduce_mean(accuracies)\n",
|
||||
" return l, accuracy\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
}
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "HCgTZ0MTt6vt"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train(ds, hp):\n",
|
||||
" m = mlp_model((28 * 28,))\n",
|
||||
" opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n",
|
||||
"\n",
|
||||
" train_losses = []\n",
|
||||
" test_losses = []\n",
|
||||
" train_accuracies = []\n",
|
||||
" test_accuracies = []\n",
|
||||
"\n",
|
||||
" i = 0\n",
|
||||
" train_test_itr = tfe.Iterator(ds)\n",
|
||||
" for (train_x, train_y), (test_x, test_y) in train_test_itr:\n",
|
||||
" train_x = tf.to_float(tf.reshape(train_x, (-1, 28 * 28)))\n",
|
||||
" train_y = tf.one_hot(tf.squeeze(train_y), 10)\n",
|
||||
" test_x = tf.to_float(tf.reshape(test_x, (-1, 28 * 28)))\n",
|
||||
" test_y = tf.one_hot(tf.squeeze(test_y), 10)\n",
|
||||
"\n",
|
||||
" if i \u003e hp.max_steps:\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
" with tf.GradientTape() as tape:\n",
|
||||
" step_train_loss, step_train_accuracy = predict(m, train_x, train_y)\n",
|
||||
" grad = tape.gradient(step_train_loss, m.variables)\n",
|
||||
" opt.apply_gradients(zip(grad, m.variables))\n",
|
||||
" step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n",
|
||||
"\n",
|
||||
" train_losses.append(step_train_loss)\n",
|
||||
" test_losses.append(step_test_loss)\n",
|
||||
" train_accuracies.append(step_train_accuracy)\n",
|
||||
" test_accuracies.append(step_test_accuracy)\n",
|
||||
"\n",
|
||||
" i += 1\n",
|
||||
" return train_losses, test_losses, train_accuracies, test_accuracies\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"autoexec": {
|
||||
"startup": false,
|
||||
"wait_interval": 0
|
||||
},
|
||||
"height": 215
|
||||
},
|
||||
"colab_type": "code",
|
||||
"executionInfo": {
|
||||
"elapsed": 52499,
|
||||
"status": "ok",
|
||||
"timestamp": 1531752103279,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"photoUrl": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": 240
|
||||
},
|
||||
"id": "plv_yrn_t8Dy",
|
||||
"outputId": "55d5ab3d-252d-48ba-8fb4-20ec3c3e6d00"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"('Duration:', 3.9973549842834473)\n",
|
||||
"('Duration:', 4.018772125244141)\n",
|
||||
"('Duration:', 3.9740989208221436)\n",
|
||||
"('Duration:', 3.9922947883605957)\n",
|
||||
"('Duration:', 3.9795801639556885)\n",
|
||||
"('Duration:', 3.966722011566162)\n",
|
||||
"('Duration:', 3.986541986465454)\n",
|
||||
"('Duration:', 3.992305040359497)\n",
|
||||
"('Duration:', 4.012261867523193)\n",
|
||||
"('Duration:', 4.004716157913208)\n",
|
||||
"('Mean duration:', 3.9924648046493529, '+/-', 0.015681688635624851)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"#@test {\"timeout\": 90}\n",
|
||||
"with context.eager_mode():\n",
|
||||
" durations = []\n",
|
||||
" for t in range(burn_ins + trials):\n",
|
||||
" hp = tf.contrib.training.HParams(\n",
|
||||
" learning_rate=0.05,\n",
|
||||
" max_steps=max_steps,\n",
|
||||
" )\n",
|
||||
" train_ds = setup_mnist_data(True, hp, 500)\n",
|
||||
" test_ds = setup_mnist_data(False, hp, 100)\n",
|
||||
" ds = tf.data.Dataset.zip((train_ds, test_ds))\n",
|
||||
" start = time.time()\n",
|
||||
" (train_losses, test_losses, train_accuracies,\n",
|
||||
" test_accuracies) = train(ds, hp)\n",
|
||||
" \n",
|
||||
" train_losses[-1].numpy()\n",
|
||||
" test_losses[-1].numpy()\n",
|
||||
" train_accuracies[-1].numpy()\n",
|
||||
" test_accuracies[-1].numpy()\n",
|
||||
"\n",
|
||||
" if t \u003c burn_ins:\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" duration = time.time() - start\n",
|
||||
" durations.append(duration)\n",
|
||||
" print('Duration:', duration)\n",
|
||||
"\n",
|
||||
" print('Mean duration:', np.mean(durations), '+/-', np.std(durations))\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"collapsed_sections": [
|
||||
"eqOvRhOz8SWs",
|
||||
"PZWxEJFM9A7b",
|
||||
"kZV_3pGy8033"
|
||||
],
|
||||
"default_view": {},
|
||||
"name": "Autograph vs. Eager MNIST speed test",
|
||||
"provenance": [
|
||||
{
|
||||
"file_id": "1tAQW5tHUgAc8M4-iwwJm6Xs6dV9nEqtD",
|
||||
"timestamp": 1530297010607
|
||||
},
|
||||
{
|
||||
"file_id": "18dCjshrmHiPTIe1CNsL8tnpdGkuXgpM9",
|
||||
"timestamp": 1530289467317
|
||||
},
|
||||
{
|
||||
"file_id": "1DcfimonWU11tmyivKBGVrbpAl3BIOaRG",
|
||||
"timestamp": 1522272821237
|
||||
},
|
||||
{
|
||||
"file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K",
|
||||
"timestamp": 1522238054357
|
||||
},
|
||||
{
|
||||
"file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ",
|
||||
"timestamp": 1521743157199
|
||||
},
|
||||
{
|
||||
"file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-",
|
||||
"timestamp": 1520522344607
|
||||
}
|
||||
],
|
||||
"version": "0.3.2",
|
||||
"views": {}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
@ -1,935 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "b9R-4ezU3NH0"
|
||||
},
|
||||
"source": [
|
||||
"## AutoGraph: examples of simple algorithms\n",
|
||||
"\n",
|
||||
"This notebook shows how you can use AutoGraph to compile simple algorithms and run them in TensorFlow.\n",
|
||||
"\n",
|
||||
"It requires the nightly build of TensorFlow, which is installed below."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "TuWj26KWz1fZ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -U -q tf-nightly-2.0-preview"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Cp7iTarmz62Y"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import tensorflow as tf\n",
|
||||
"\n",
|
||||
"tf = tf.compat.v2\n",
|
||||
"tf.enable_v2_behavior()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "3kudk1elq0Gh"
|
||||
},
|
||||
"source": [
|
||||
"### Fibonacci numbers\n",
|
||||
"\n",
|
||||
"https://en.wikipedia.org/wiki/Fibonacci_number\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"height": 187
|
||||
},
|
||||
"colab_type": "code",
|
||||
"executionInfo": {
|
||||
"elapsed": 709,
|
||||
"status": "ok",
|
||||
"timestamp": 1563825398552,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"photoUrl": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": 240
|
||||
},
|
||||
"id": "H7olFlMXqrHe",
|
||||
"outputId": "25243e7b-99a7-4a6d-ad00-e97c52be7d97"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0 : 1\n",
|
||||
"1 : 2\n",
|
||||
"2 : 3\n",
|
||||
"3 : 5\n",
|
||||
"4 : 8\n",
|
||||
"5 : 13\n",
|
||||
"6 : 21\n",
|
||||
"7 : 34\n",
|
||||
"8 : 55\n",
|
||||
"9 : 89\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"@tf.function\n",
|
||||
"def fib(n):\n",
|
||||
" f1 = 0\n",
|
||||
" f2 = 1\n",
|
||||
" for i in tf.range(n):\n",
|
||||
" tmp = f2\n",
|
||||
" f2 = f2 + f1\n",
|
||||
" f1 = tmp\n",
|
||||
" tf.print(i, ': ', f2)\n",
|
||||
" return f2\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"_ = fib(tf.constant(10))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "p8zZyj-tq4K3"
|
||||
},
|
||||
"source": [
|
||||
"#### Generated code"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "UeWjK8rHq6Cj"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(tf.autograph.to_code(fib.python_function))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "eIfVy6ZTrFEH"
|
||||
},
|
||||
"source": [
|
||||
"### Fizz Buzz\n",
|
||||
"\n",
|
||||
"https://en.wikipedia.org/wiki/Fizz_buzz"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"height": 119
|
||||
},
|
||||
"colab_type": "code",
|
||||
"executionInfo": {
|
||||
"elapsed": 663,
|
||||
"status": "ok",
|
||||
"timestamp": 1563825401385,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"photoUrl": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": 240
|
||||
},
|
||||
"id": "33CAheYsrEQ7",
|
||||
"outputId": "2a88b65d-4fed-4d96-8770-0c68ffece861"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Buzz\n",
|
||||
"11\n",
|
||||
"Fizz\n",
|
||||
"13\n",
|
||||
"14\n",
|
||||
"FizzBuzz\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import tensorflow as tf\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@tf.function(experimental_autograph_options=tf.autograph.experimental.Feature.EQUALITY_OPERATORS)\n",
|
||||
"def fizzbuzz(i, n):\n",
|
||||
" while i \u003c n:\n",
|
||||
" msg = ''\n",
|
||||
" if i % 3 == 0:\n",
|
||||
" msg += 'Fizz'\n",
|
||||
" if i % 5 == 0:\n",
|
||||
" msg += 'Buzz'\n",
|
||||
" if msg == '':\n",
|
||||
" msg = tf.as_string(i)\n",
|
||||
" tf.print(msg)\n",
|
||||
" i += 1\n",
|
||||
" return i\n",
|
||||
"\n",
|
||||
"_ = fizzbuzz(tf.constant(10), tf.constant(16))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "Lkq3DBGOv3fA"
|
||||
},
|
||||
"source": [
|
||||
"#### Generated code"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "bBhFIIaZrxvx"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(tf.autograph.to_code(fizzbuzz.python_function))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "BNRtprSvwJgk"
|
||||
},
|
||||
"source": [
|
||||
"### Conway's Game of Life\n",
|
||||
"\n",
|
||||
"https://en.wikipedia.org/wiki/Conway%27s_Game_of_Life"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "r8_0ioEuAI-a"
|
||||
},
|
||||
"source": [
|
||||
"#### Testing boilerplate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "7moIlf8VABkl"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"NUM_STEPS = 1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "QlEvfIQPAYF5"
|
||||
},
|
||||
"source": [
|
||||
"#### Game of Life for AutoGraph\n",
|
||||
"\n",
|
||||
"Note: the code may take a while to run."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "5pCK2qQSAAK4"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@test {\"skip\": true} \n",
|
||||
"NUM_STEPS = 75"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "GPZANPdhMagD"
|
||||
},
|
||||
"source": [
|
||||
"Note: This code uses a non-vectorized algorithm, which is quite slow. For 75 steps, it will take a few minutes to run. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"height": 309
|
||||
},
|
||||
"colab_type": "code",
|
||||
"executionInfo": {
|
||||
"elapsed": 147654,
|
||||
"status": "ok",
|
||||
"timestamp": 1563825336196,
|
||||
"user": {
|
||||
"displayName": "",
|
||||
"photoUrl": "",
|
||||
"userId": ""
|
||||
},
|
||||
"user_tz": 240
|
||||
},
|
||||
"id": "hC3qMqryPDHS",
|
||||
"outputId": "56a095a3-28a3-455d-e95e-2c4c9dcd97d2"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"\u003cvideo width=\"432\" height=\"288\" controls autoplay loop\u003e\n",
|
||||
" \u003csource type=\"video/mp4\" src=\"data:video/mp4;base64,AAAAHGZ0eXBNNFYgAAACAGlzb21pc28yYXZjMQAAAAhmcmVlAABdAG1kYXQAAAKuBgX//6rcRem9\n",
|
||||
"5tlIt5Ys2CDZI+7veDI2NCAtIGNvcmUgMTUyIHIyODU0IGU5YTU5MDMgLSBILjI2NC9NUEVHLTQg\n",
|
||||
"QVZDIGNvZGVjIC0gQ29weWxlZnQgMjAwMy0yMDE3IC0gaHR0cDovL3d3dy52aWRlb2xhbi5vcmcv\n",
|
||||
"eDI2NC5odG1sIC0gb3B0aW9uczogY2FiYWM9MSByZWY9MyBkZWJsb2NrPTE6MDowIGFuYWx5c2U9\n",
|
||||
"MHgzOjB4MTEzIG1lPWhleCBzdWJtZT03IHBzeT0xIHBzeV9yZD0xLjAwOjAuMDAgbWl4ZWRfcmVm\n",
|
||||
"PTEgbWVfcmFuZ2U9MTYgY2hyb21hX21lPTEgdHJlbGxpcz0xIDh4OGRjdD0xIGNxbT0wIGRlYWR6\n",
|
||||
"b25lPTIxLDExIGZhc3RfcHNraXA9MSBjaHJvbWFfcXBfb2Zmc2V0PS0yIHRocmVhZHM9OSBsb29r\n",
|
||||
"YWhlYWRfdGhyZWFkcz0xIHNsaWNlZF90aHJlYWRzPTAgbnI9MCBkZWNpbWF0ZT0xIGludGVybGFj\n",
|
||||
"ZWQ9MCBibHVyYXlfY29tcGF0PTAgY29uc3RyYWluZWRfaW50cmE9MCBiZnJhbWVzPTMgYl9weXJh\n",
|
||||
"bWlkPTIgYl9hZGFwdD0xIGJfYmlhcz0wIGRpcmVjdD0xIHdlaWdodGI9MSBvcGVuX2dvcD0wIHdl\n",
|
||||
"aWdodHA9MiBrZXlpbnQ9MjUwIGtleWludF9taW49MTAgc2NlbmVjdXQ9NDAgaW50cmFfcmVmcmVz\n",
|
||||
"aD0wIHJjX2xvb2thaGVhZD00MCByYz1jcmYgbWJ0cmVlPTEgY3JmPTIzLjAgcWNvbXA9MC42MCBx\n",
|
||||
"cG1pbj0wIHFwbWF4PTY5IHFwc3RlcD00IGlwX3JhdGlvPTEuNDAgYXE9MToxLjAwAIAAAAQZZYiE\n",
|
||||
"ABH//veIHzLLafk613IR560urR9Q7kZxXqS9/iAAAAMAFpyZQ/thx05aw0AAQoAAjZrf0Z7SQAFS\n",
|
||||
"RBmrGveunhOj4JFso/zYXaRjQ18w/5BhxFIRpIkBeRXl9T8OOtGMbM52JtIMXIY7KRr49/IsKi0w\n",
|
||||
"jJUK8Z7XIFmlAjIU+jSbWER5LmeK+6/diSLijDB3co/ebDgChTdnt/smJJAlFMJhzTUcdwoA8NQo\n",
|
||||
"YBnpXwCtHd9MDNyz4x4zrqfgfXAXtVDOuKqK+ZIROmkudESU5HAc84NxG9mIFkHTHpfRFX0vfuvN\n",
|
||||
"v30XneTe8IilYhOJYkyOcVBz9L5D3N5P2RHbPf8d2Ia4qkwGurGLJl8PxjFsKE4dm+f6WYtxh4/M\n",
|
||||
"EbibuuIVHuFVTrhDBdjGsnlvGJ613cHSu4frv4bqhIfOz9nOKI/zhLw9zlvfAkAek0G+jTz8be7+\n",
|
||||
"o/ndntGdno6L1LXJpdgGJYFOyZwDpk3suJqu9FKdCFsjDfQ4s5OYpZkBRm/h6ksvqs/jKOI7H7Eu\n",
|
||||
"JEDtMn0Px1875SS+KLSHaHwtTCNzTTTEE83rjSnRcLH2qekoCAzC/F7u+tWoo8/5q7AU8ZwbFyde\n",
|
||||
"C0AcLGLOTLX2dctD5sMzDYlYtX/lYiEND4SUALBVfbetB5IH67pM/22hp7cM4zkyUfekvXZeKUpq\n",
|
||||
"ihxpjZ/b0GfRGel+eaIkRAMer8l0HHBl4xOpdwEUiGEQqacmsmAKA7/Wn0I4FZAkAeHbrP6JQw8G\n",
|
||||
"T6oLn8jHc2YBwe6YY+t5SuugRFwnijdFTQ2IYMHZ9spzZjJhn/lftFm13UY9ay8CDty2j8dXZfss\n",
|
||||
"pdN3RSB6EMFrirN6yUkoxa8UPGBKHs9MUFO5MnKDgADHT4JhBGInxUASlDV0lsFB0GH9ED4tkRc6\n",
|
||||
"7SnaMmZwf9T2i4a1NSsheM+jHEQWr9fgPDBABuIyToLYLrnVeLXqSC8JMeZigh4GOpQKyiIsG8oa\n",
|
||||
"f6kiBTwG/5RebTqU6O7rrQLj5Wd5YFdqaacUZGByo8AxJ60NHIoQcxeNjsWAj6m8SKd2+g3en70+\n",
|
||||
"zVQW9HkvHI7nnRF3FhwhZYu/LvproEPyWSYykJIx75ojR14WE7oWSjYs0X2AFiwEouayVGii6owJ\n",
|
||||
"gdlCmnN8HoqT5PPnaOWG7mPgq/3meUuz982ZX4+4VMage3Fe0K3cqRdKLTge+gs4pyQbSUIdrgo3\n",
|
||||
"4P4R1ejF0wAW1R8YjLZz6fQUzzzchgNN0t7aa8tlO2yDCmII5BbaYJXJrRvBm8Lb1m7TLILNalgu\n",
|
||||
"RMjYD4Pf/P4iQqWsBEdgB3p334RMzrBfcviq+49N2SRQlYxV0SbSMdybZaH+vxuw+VyvLt3ulEcF\n",
|
||||
"rmBwnxL4kpGATPv8mogAAAMAUMEAAAI7QZokbEEf/rUqgAYz+kaAoYS6oZnCZBWChU49QzRvBVh/\n",
|
||||
"3Pl1tY/3h6ui3wW2qKCfpdwQ1h/uuKhRazpong7+Xsbw5g3mv3E7I0N68sUiey8Dbt0hMUrR6zYj\n",
|
||||
"YtzMQ7gEdgcbbOEgu3H73w44JvEzvgZ4iO4Q2Kwp7BHY2uxxtdUENoG1kHXqnnQawFSCHZ9W6pRZ\n",
|
||||
"ZX580jW/ekv7tzX5SLrr2mknIiIEL/9OqO/hdKRWyIS92L0VbeMgboQPIpdXZEemH8ScfWR641oo\n",
|
||||
"Kb2ZqixayrynX4qeQdDAXvtKdnTPfgTsOJHs6zrnaaKb6SpoCg9ffzFUfiQ1YwLPZpLhwkJ1F58m\n",
|
||||
"QtliSU1LCArOxcL0CdX1xv0PO1XbIga8mvD2ON78HrYIlpd7r9MIJUgGiGlRxLTUITjvxtxjLYBG\n",
|
||||
"TBzSQ2Mqy08Y4xvBh9/AZrWGoBvplKVOooBAXsS/J3OngcAaMApnGniTlEgacIB/4ihqQm9Zync1\n",
|
||||
"WrLEldONGr9K6gbteZcFnK/hoe6B53agN6YwjF+Hm1IYltzK42eiNQbmeo0nT6xx724Sek57Pcpp\n",
|
||||
"/+64lZEYNhMLw61j8cLCmWJLqJ9+OlV3Tu4kvqWM5A7mBmXunK5EElFvFoiaHvfKnFzVKUZHVN47\n",
|
||||
"dwwOu2bQK/GEFcs57H1A4Ddl2JAlJt4ZWgrJx+vzAgyhhcl1LtQgQcd3rX3aPisDf1CYETnay05i\n",
|
||||
"xe8yUL0AVMzI07+lqERP6auGU//nlrslfAAAAS1BnkJ4h38AGAsZbANezx+IWo4Ni9MoMfKTC08P\n",
|
||||
"cqaDTyueeuPLGgFgW9U33kZ+Bw1xhP+VnfaIAfTxYvkb1WNMMRMsh5PjwSMCmaFIlQvFeKZwdgkf\n",
|
||||
"0eHuoCcg/XQXRqCvEyyYU7Kr945fY16Tu/18Zd8NU8RAJRLFspmBVoIPZ/aTPmSIXDq8KOSzL6TG\n",
|
||||
"sWN+V8RKxGwExIfHZpdEvHu1tOeg+pVzKTracfnYiBxxlkuVIyzOz2mFv1LQ72jZQGocAdWS14tD\n",
|
||||
"EtCsmNljiTGQDRggnoajq8kpnFHws9ZMWmcsV4dQvczexFmx4YibNvvMPauj3CH/KK6FXvQFumid\n",
|
||||
"ftiga3Uno6si2epmOuEVTuVQwXsgCmOyejpjAiAjZuUS1zq40WginD1EPNgRAAAAXQGeYXRDfwAh\n",
|
||||
"r6zZu6OyBrfB5mVsAz3QNRRqvrwAcnFznD7NXanOaWlAADNOwlJX/xGmO79sH9XeNRT/FnLuEPBH\n",
|
||||
"1GJhJV/Xt2R0YziQPpgXV9BLMr5IaMaU9R2CpgAAAPgBnmNqQ38AHhCAmS1kGlkSnBkADoOXdXaF\n",
|
||||
"NGZr+Q4fCvQ7bHDsrrZk+gghfDnB3EgAw+hgyCz7QjPCBdm4Oua2VioU2d4nUZ+UABLNnRNNghIa\n",
|
||||
"znH4EU6++iAxhcURNicOGGgil2sQO5YirsL6J7S/TznXYcILcn91E9qrSkdqAKeiqMttbt/NlBlt\n",
|
||||
"zFtTLIQV87eeTgQtRSaGjNkYcjtT9zsSroMxdQkaS/rgzWfPKqioru5///iiFvV7FHhGNapsB8Ep\n",
|
||||
"xA6YqLEIyfxd3iBKiJ3g/96H/WMQrMVl8ykLYh6g9L/mEknpMxDRuX+/d5vuR5TJpN2l4QAAAY9B\n",
|
||||
"mmdJqEFomUwII//+tSqABipnkgGrJGhoF2xhqIGFJgrTiV28TOHP6iMSZwA4LzauSvgcy42/qpKz\n",
|
||||
"PF+GKWIn2EJeWsQWOqhnFWAeu8Qy08RHEYzw2BIfhXKPnsvQ1D45gRUsCZjYq85tliORVeVqHlvt\n",
|
||||
"fzWrMqI5f+favhs74Q/1bo2ebSMVUSFuP3HPqFVDjXrf/wjJSgWTFPNzCZtjDghfnhYgAzPVh4sd\n",
|
||||
"mfpnfQi7UGcAu+X0SPRW+sCzjBKyZsabYXRLvCvcRgXcWHRJnqJZ7DbIL5Ahmra4MUmiAdrDqxi1\n",
|
||||
"yixz8Ge2MnwDKePhHbASj9FgVyabApZmODkYAk9x2eNsu3NC/GWuEsOYUEJXb3NkJ3H0Ehpogb5q\n",
|
||||
"/7IADF2Rk2r94PZTFE6TdqRa+DeKrhf1PoBJxN2bNx2sA7Pci476Sn+ZpPsAPTlXaikJNRAhO4tD\n",
|
||||
"lakPd29Edmfvk34bCqY6rFMuCfUJ3yzCy+VRKB59CtgS68dVzaJO/FxZ2Of18yjXsScM2fL16/kA\n",
|
||||
"AADDQZ6FRREsN/8AHa60qBaQmR4IRAA6Dl3Sc6VtGJbtr5vbN23f25BY5Mbt9ZodJaqeGLgSZDt5\n",
|
||||
"tMt3+exLq/o1or+DyDOaUjfDuI6HO9EMKVIFrK5bBNySwYGQ9ZOLXviohcSZAskgQCT8YbljWqgY\n",
|
||||
"W5O+m+Ip3OoA9JMxAp4EiGRPR1hmuQDeRomyGX7bvvzp+lmhQcgx50Gtf2FsWph71RE5OIfz3vbU\n",
|
||||
"YPJzvstNoHMLjQVN28uexbTk/wUswGjCQ8u5AAABFwGepmpDfwAhvaAbJNR/9ddNI1ZNZPr5vm6q\n",
|
||||
"XTetXH7Eo8GqFltKJbOb+WxFxg1OZ9LY7Pm4G1n+FvJzAc9iMK3kbM6geeeFIdRl75A0UZYsXIff\n",
|
||||
"dQXiQxB/kP/GUeJS/ghHdsFXhovY2ei0jBYXhl7XCQdiM+OxqVpdBNYdLY+vhvtTydDweWAQhmfY\n",
|
||||
"3fYN3w2o0+YtvleCAQNIu+tN7OfSeOifT7EOLQk4YDYkvT1QcI6scYDf1en6ihiP1DSq11Clzx8a\n",
|
||||
"ja6cddGuoMqDaNkxCF1dzf2Jvz1VA4BpWPjukcCUvSBL5Hjn5IenmZHNevhC9Ri5TKMMAK1OUZos\n",
|
||||
"eUJttkHLI36Z4EqqgVQeXc7fMR78LG9GpQAAATJBmqlJqEFsmUwUTBH//rUqgAcd7WUAG1wL+eMP\n",
|
||||
"5NbNjI1PanDtCkQqkSzemsYEjSdqyjDQBhMRhcVkBjrLnQ37QRY6anUo9HtaOXKEvV3Oq3t3zJnU\n",
|
||||
"VnRnO4+DsYDha+hVjf2RQfz8iIHBAMZBzDCidKRjdK++FyTTJT//wjjoyDzrLD81EvvOEfP1hNq1\n",
|
||||
"E7Mf/LNi4VzZp3xaz5k3oYD4Uh8itElOoUglEcP1/ghF2UcJA9hOtkSUpVhA8+T8Ytc1zpVMfYyg\n",
|
||||
"QqbyRa4EvI2+PCgNWtypZmPOW/fUb8LPNYTg5GLhzbOmSjYpenEUzkib0QksNLKbj/E9aHrV1qHX\n",
|
||||
"qXiny+3UUPxYGvj/pDuYRozh1EchMNkv/eHEkrQhTQjnyxDirLtyAwkvICbz8w9UK2AAAAC1AZ7I\n",
|
||||
"akN/ACK9oCBuM4cceanCEEWpV8cuy27lpLcHp0RFJ/onjSEljOG8VqS2Rkf30kIRre+KMlNGVcvp\n",
|
||||
"cL4orO6Yp5KjC/RRBwQz/yE8UKLNeO0Y0FFhQfICXcBtO9ndieTXXlspFHuGf4S6CeBKlAO/lDFn\n",
|
||||
"Bm6rf4RqP1vvLrD8KUBlig+AFH77l/U3BNsHxmcjURJ4rz9SBUp3dWhkBmKNCP57UtC9bKnqFyE+\n",
|
||||
"YvACZ+sMCAAAAZlBms1J4QpSZTAgj//+tSqAClE1egBKEwbZY3t792fWy96pbeQQCnoXHta8keYB\n",
|
||||
"6YD4iyrisk5RAGXAP8hftXkqsIp3gIADtqeyulunIxMvA+tHyMYI4mH7Ktx24JQCDLGwr+SW5Lfl\n",
|
||||
"LFzLN5Z5EpfMBtjuN1e5MGJfkKE7RLofReD1fgshPg5Hiu3eNzKNtXPqCUQOQrANHyjLVDHW1On8\n",
|
||||
"GbpMg//3+EW5h//MyUrV8C3bm65GCPAdr+IiAQS5PLqRpJaqPFXYImLzCfEF4IcxGqfKzcnaOGUe\n",
|
||||
"P5zhUa+at6SYruNLfSBlr3+mvyhAAxPUBpQBX3a2ZIbz3QLaxiA/KmUnrCDmuWAQmEAoRWFYDkhB\n",
|
||||
"vSu304LzlIj5BSPPqNvyTdiIsLpzAu+SwxleN8rOU8p84R24aRhgQwchoF64pWQkYvhDlixS1XkC\n",
|
||||
"+1BFsz/ugThqWNrj6DMWcUAmd8tN3JWA8raGQmJpBH1Zjd5483GFE2+DssYAdvIzFktdYvwqJy33\n",
|
||||
"xqAAiKb/jZmChnRmwaKmyp+usNPBAAAA+UGe60U0TDv/ABgTM0cFpiU9S5COo+Eq1a5EDpKRq+6p\n",
|
||||
"lSs4dhBzMdhHGYju3Syu9sir+n5TA4S4EozXRjp4djOH9s6Ebl4mnuRqUkAVVyRRxloLXXdAVwvm\n",
|
||||
"Kw2kt3nH3KtGiXPZtoKRlLMwsYrakek54VGjJMSSK7z2j4bZfzdU5fWILhtGELYhukSGMv6CXtq0\n",
|
||||
"ugZLCx24z5CJjXHZ6aJugoOXVvLE5AMKcYDe/LowGji7OLeFgeB849mfSaUGlnh7jxuhBOU+fRS4\n",
|
||||
"p0ITI4vXzUUR4XVTQrOXBNie8HQwoivm+WRv0nW15Zl5mZ7wAnqm6XldppA1IAAAAMIBnwp0Q38A\n",
|
||||
"Ir2gIG4zgb64sxYLzhi9P+r7lwy6Wa7RRkAjTYM9mY6ueOaRzgw6T2RlVKQ/Wnw9OUPsoB+98v3K\n",
|
||||
"7Ai/8Ku9oiX4fIaC4XxFxl+0lQDznNsd4UfPo3AQh6FoBHug176P/7mBbtXW9HioX3mZhTRXJOlh\n",
|
||||
"Psk7HP1i1klJ4f63KMPuZvFOjkq75Z+u+/aiOQvmn6+lP0r2vSaqs7nxNSGwPqSwNXaUgQz58aD0\n",
|
||||
"pB2v6eKf+Yy3eGu8f7HHrAAAANkBnwxqQ38AH77opN4Quy1TZxAAOg5d0nOlbRa1oa+CUrbGUKO9\n",
|
||||
"s1K1K60LxAZlk8ZQWiHU0UUuQDnHAAyjelIcwOj4NipQdTlRBT+HrLVCVEK5smCT4WEyhlST21vf\n",
|
||||
"pS9QIx6rrJJt1ZwRk3fLMy3lh+GbSU8p/deKiRgvPKu2y5xljT8HokdUfoJBN0b+9AYNdPwZxzfv\n",
|
||||
"wRj3rjB+XbCQdH7rLOmVBWtc7YBBcmnLfJ50Xx9vsPrIGyT/orCu88gDS7Q97WNMWaRoINuEV0SN\n",
|
||||
"7lASQ8YC8xeRAAAByEGbEUmoQWiZTAgj//61KoAGg+KazAhO48Rk+mELCfGa3jedcL7j4gDd4k3m\n",
|
||||
"hfDQA786lCeWa51/s1J2qe/kkvnBjg4L/5tqqnPuWzD5CtqsuCrBZfD9tieYn0V6h2QRjHTgf2S7\n",
|
||||
"KbBJVduRkgXz0DCyLCsDRdQx7ZVeilFNQPYHPpL3dFbV2ZQLhZ15DCVv0ijUbfdtbaCxQWk4hFwi\n",
|
||||
"4Cl7Vcv5eumMKNjbBf29eX+p4vfxRMeLxQVGLH+o2FLpf2SZwh6nFX8ReHwFB2aNAZojees14KLO\n",
|
||||
"dDXVOKLwRfawG/F4iTHLNjIHr9KJ7RMP+ZW2v4UodTEwj2IkfoeugjPYygxsYBEN/HIWo7Lp4BiH\n",
|
||||
"W+sGNW6nzMrLHeZnfPrIXJzjKMZ2dMe3r2TPoxLKTVgPHlFgXbB9gOVEkvjr1YtxEt3sHivjr7TH\n",
|
||||
"zrmzrXSS01xk914HSqt/CnYSKPxa2MF69g9I/BNJSHdHCdNGwRVm5U4w/DYDySkJOTHhPK5xLTdI\n",
|
||||
"6pomON2J7Snu3IFO1cMuZQAgHAwoynkWURtTVoyQbA1o0XW4HcVte0xmLSUrxW27KPhiReLpDIah\n",
|
||||
"P07+6UwIug2Iw2yxWwAAAP1Bny9FESw7/wAZUxOT3tiejYgyJDRrCYHaMUHhX+buBbaoqZ/1iUWs\n",
|
||||
"Jb7slI/imiQ6OnWj09SEskbfc/zlMQQ4SNXZauWfHJ95XYh7wMFGgh1p51IG9qMewyJwQS444Zn2\n",
|
||||
"viLgUg5+yrpXHCf0t8/9jDlbqwjDulbT62pdxpAyxuynsO8RFT3dUKeSE5htp/jbraDowEdpXZyE\n",
|
||||
"hG0WYkl+RbztI/PQNZCwZsz+nvpxvKr5XHM1hBpXHcYTolc3yg25EknXG5iovx0Y9EuSqthrt+Xw\n",
|
||||
"mK43mYVJUVC/Oh8GeZYMuS8/kSjScKjb9J2cbfyAxgmK23G/LX345QQtAAAA2AGfTnRDfwAc/TTk\n",
|
||||
"s3FNYSmNHdPgDfXQC1GBEwJGCqSU6MsmeFhDrrArJ4DXkS7h5Olwl5LsAdAjNSMWnsyuwfwlhiS4\n",
|
||||
"Iu9nXiMR2gsFQTdJfxAGWv/oGKrfOpY9OM+oH5mmAEYRbo0uYIZjYyyv9H1tg0RX725ktocEeT9I\n",
|
||||
"3B3Tp4qYCOAxN7JPiw1LGqnL098ntFu5ng1+yPoA7ayjGtnhqUNzDdxHw06qdCQZykRFXaAS2mFv\n",
|
||||
"lmomA2wH7gnlU4hH+9/QtYxMog0PKOypGE94HJSUfoT7gAAAAEEBn1BqQ38AHE7WHA5VnN1RP/m4\n",
|
||||
"B17wBGTsyVXKs9N7WlI9AxsJJ7v9zVkMjf6pvv+Cg6JoQ3BLOK7r3bcONYUtZQAAAddBm1VJqEFs\n",
|
||||
"mUwII//+tSqABlJow5npTNmtYD16z8AGI7v0s/GnfyqOWKggEMwd90EmHsgCWksYKFE4Qru8Yv50\n",
|
||||
"LqOKJvWMLHGzKIf1mWoops1hD8q4hCLJMEdRItKEcO/AvOw75DCgogAQMHz94YdBlV1FB7/3PGw/\n",
|
||||
"kvp11c7Zd3bjgbTV5f9wCrj5V98Wrk1QkXKTao3xn1WeAORpyCtFJo3KIIzvry0ktsvXmShsZdHK\n",
|
||||
"SF2Q6qY6Id0i1QRrrPRdF2iq2m2rhv1eY7FLgTuR+kimJsshiQFr/qQ4tOO2msQRBI4huY4JSA+L\n",
|
||||
"KftHgweMeBwJfCg9ocoILqar/ZxuCC1Kx59hrQRJPfm8amRIkwU/k+wKJNYh9fLLSBsxlrg4XoMn\n",
|
||||
"PzXBXS36HS/Vq/PUU0Saj0Ks8oGCHCVcz3eoIxgiU+QJY/DixHlF4+MYR1JrL+dYLi5XU6rOa8uy\n",
|
||||
"cymZbC8fCrT8nFmCuYcD3DNSzmKt2Ypk8ahqcNxMHCCE377w4QcAAK8hLicCDiuo9KVio6ugqDQM\n",
|
||||
"DiWya9QmBn0ClIbSCznyVdfSZyODo1gjrJ9IiCMcnWI45hcgB0F/w3f4fUDX3TFD/vbMoTmxwMKV\n",
|
||||
"hWEq4XvI4IEAAAE5QZ9zRRUsO/8AFKVUcHl/E43Gt6o4RZvBs+iAp/X/n7d7Pz7RdmO0J7CPEDVr\n",
|
||||
"YOGCwg4aa5sRnK1DwPx5sIYzP38566ezpK1+yb8tpnK38Otysb+fPORXq89pSQ+5zLmadq08PRPq\n",
|
||||
"ft5b+CuHdsaohxgMdfr5HBiNNodd0VK8TNpXmgIXzYR5RpK7ScM1kMS9Nv/EnJHMV/HrvGwgTDTj\n",
|
||||
"k64XWbP6seQRZKb98opQD+okWzwHsAFj5ehr/ekl0IlB4NOOkEs2vqjJoc0vIcwkba8FSFkLe2wm\n",
|
||||
"HNG8c/q9E5Tipy3avrHlLTvT0bjPkjeD4HLfC3isImW2RvjzyyF2TiLuxINvE8y7u04RbyNnhNhC\n",
|
||||
"J15BQDsVja0XtFDfnnr/h18foOkLRpLJ1yQTMBboYsOrVzSZ9GDWwAAAAM0Bn5J0Q38AHQXz6rvN\n",
|
||||
"uarixND043ZCNdAAIHUCWbOjp5TUpZdEciERk/s2Hj36k/1QHuy5AO7bU6FcTtkwLNXpp4kEhhr2\n",
|
||||
"pj14tuqcy7uq8XfveV+qzHFw516IWJuk3fnleTKVnyg4EmdGVkh8uUm8KAFIin8/UzurGkP5FXB1\n",
|
||||
"JS0uIqtx2mbD94hCpeHMsXHXmWbW3GUD6bwQzUCwUdgGFWWOBIzHIH3jzzxIIZ0rnTzx6fd8zSRM\n",
|
||||
"hMrhmhy9AElVESMBSl9RUVwHxFBAAAABSgGflGpDfwAhvaB1qIOto5yaJpOYSSkbksLCkPuZStd4\n",
|
||||
"LeT7CV/DcB+jLm/y8AhlFfeod4crFEXxelJR/fWiWC5cEAQJB3xoICKkbqYOm6EmFwfhOJrnHL3F\n",
|
||||
"i7egoJ4YJywxTcfWExKLj/7q5Qta5s9pQnji3v49xEhquy1bNbsP/0r8degDcM/eCvveCCuWJP4W\n",
|
||||
"kmgZOsTL6w2RcANA9FiGFsZYFgwwIJNSoi5uPhHUWhw8DgpZUJJwhbcwAlrJ/XkpDgMQdv8+KTaK\n",
|
||||
"5RNrXWUI+DQboZuQqh0EP6Ucm1iy8BiBubHVtPfvfM6aTMlQH2sGDo7kxk+QnIaS5zzgTFrv32D9\n",
|
||||
"yKVtBoqoPJ0AuZgM4FsUTuUjy7Mb8fU+FNoSPESiOFS3CYbvMWBzWtiplx16c8G+2sTGiL+yia5h\n",
|
||||
"U5UjqF9tl+DCrXkPmQAAAhVBm5lJqEFsmUwII//+tSqABlvipo+ln6jP3YEZZAIeN2gdAdBG93Am\n",
|
||||
"88+PBAP+pBG1b08i0fIFrYTfZkz4SYTuxIQ1JlthBpef+blJppNwqif1piWVs/t6bCj9Z+mNxSeq\n",
|
||||
"fY1/wgLfvSZhz+cH951YQ+3lZMxDj+AnlpOYgaA5ONYw7fbC4eXvAp07e1QLTwt7AKsxs6j/dp/S\n",
|
||||
"ROqifCEiS8aS31tyrNd0WUbq8QssOlpj1+9+m64Uuc7+f7EFYNlp0SQRRU2ux+5kBFuUthOQf/99\n",
|
||||
"ODAIvGEvExgFy7U9xycg96i+XWorpOkUsmc8UuZbMVhIEf4MYVuxmTzjhiOVDlxwcksj2gNb3xa2\n",
|
||||
"pmXlh1zp/jlUP6lnJbCcR5jJhGaBJ/wuH3P+rOiJDpAwjSIE4agxxO9XGnmQRqhYjiBkbby/Qs/C\n",
|
||||
"0p6IlpvwhBITpwXRBm1mH+MtJEskEccmYaNT1YNO6b966q1ndwWmG4wqG8yXMOLAMIGnxTjTIpRG\n",
|
||||
"9a5Z9Xdl+HR4ndQhvFfQ+mQNsGUdDPAaOtDr9NfsDESdrHz/VFsWMxlbozv6ME9/FBsTE8SLTZxK\n",
|
||||
"uKA7LtdEmFdsikvrVwkDRWs6mlddIWSLEJey878D400I9Bm2F1YzYF8hIer8urpKTRWH3dl5Pnql\n",
|
||||
"OkpPyvm3RplNwN8DaGYvFB3ajEHHx79ej7jTTF7j2dZAVPOuzAAAAQNBn7dFFSw7/wAYtYg8t2YJ\n",
|
||||
"aBl5mT7LoVquTMWPsAY8JEk7n2Ltj2VU9Y6yhnUjGblNmyV5I1tDP1WCa31R20KBx8ZAPYjEjgAl\n",
|
||||
"IBPsF6gwEF1mGQPgwIt+DQ7Ltrn+WWljoOZe6qmL3ODaEJKUCy9wZy8Qi5WMsDYzpEybVU1vipuE\n",
|
||||
"rsjD5epFom/S3CRpP+JRc2SuBGV9X135AtKz2dAbEFqb0f/DUfvRpyE/xar90tpMsUisBmDyfPqC\n",
|
||||
"QCIWsyVA62u0XX4SHuuo3VkmdASLaLWJS0hWsThucD2h8t0xx4j3t8tQeFkAoX+vhWm72BA6IAOh\n",
|
||||
"cP5AynBLYvgLjkBSaw6ZAAABWgGf1nRDfwAgt5i6arm7oDsF+i9EHiOJ6m6rVkYAHTQbG9yseMuo\n",
|
||||
"2+jJx58xpeovc881Wv+6nIPwZiRTONb2IQaBwPwYP/UAnKjoweUWtNn8yjj61Yi1F5n9oYReT9vo\n",
|
||||
"YNykd6+UIhqXBR69VB8JEqms6DNcB++Z+7S8cRY1PTjUFRAm3tXpZtcqOC46Yje8Z3mZdWtke57d\n",
|
||||
"wfIWf/bjH+PQoHPWtMGigrlGqEUElC6TETXz+nB7X3pF40yVazdjxa5pCPS8j1Bqo/RmILtftGxN\n",
|
||||
"Yu+1c8QTzG5+3qHYIB5lZeEW8bNhQmHlV1zck8pKhAWM+UMUo8Yo1gMDIjGuUuNGCTYOoVand7oO\n",
|
||||
"JxBESUm+840sI50gEtqO5mhNaTQVfGrhYgQvynil8I63rBmEOncCHtkN57Vx9gduQDjk6aOyO6bY\n",
|
||||
"qsBt2jiwg3SW9pmMOjEKBDS6IfMiAxcAAAD/AZ/YakN/ACK6K1xrl4Eswd4/m5m3eDoe6aKYRGzt\n",
|
||||
"qScyJrEz0/YMsioeM46osJc2N8un8CXkVjpps6zgsf8LlkG70ab3ccrB+um/wXzisesiYCwJDgAm\n",
|
||||
"D8ODYrLA2f4XQyaEvxMLwdPggFdV9SLGW7IaDs1Gj2MKL95CD69ggFd4PlXdr+MMXaKnRfCfYej6\n",
|
||||
"jyRkJ6YHIJryGsscniQRwJ0d+J+1KTOriJZQomY6moOkqhpxON7UIyt9lzU6HlHOyQJ+oRH5iOIM\n",
|
||||
"+hKNz7H8znQxxv6dKCBY67rZbPlwYKywoLx2OIjAEQohlh7LdbGhKMy/zzEiJYFobhp2mH1gAAAB\n",
|
||||
"WkGb3UmoQWyZTAgj//61KoAGC/pGgJ9CubE/Hy/U90CEEMEEbF2Q4cnB3oAeksXBYLQl6DX56J1l\n",
|
||||
"w/mHq8WxaGt2MnAvQ41YNYO39iE6FvpuFKpW712yS65PLr83LJiqo7HZlMfRzKZN59Hb83g9Yzjb\n",
|
||||
"LItfty44d54BI12++V5xh28HT7V7r0Y3bFC5OovybNWx1HQWDmvmM+uWQT6BKmA1pblkm0jWUuJ0\n",
|
||||
"KAyepKH6sPnyIzz9TF/cTcVBDLcJ0ebq4QoNf0i/efDFq1nH+LtoZFDiLpeCwZkCLTOE+JMjcVxC\n",
|
||||
"aWP/XfyRHhNANFDKtoVePLPasXuBVFa5xCh3bB99SWFmaQdxLlk9zHTMNOyCWoiRa9OkdBShrOe1\n",
|
||||
"dfGrU6t4YEao5nNo7umRhNJMptOYWcUtCbSBQmV/4G3c/zgmpJb1N+5bNROg3nNApsFhNWPnDxXX\n",
|
||||
"YEcAkKEAAADvQZ/7RRUsO/8AGBSepWN8xnNsxE4oE6H3s58lr1m+iqw+EfUFRD+Jna0+Uvzz41Eu\n",
|
||||
"ATVBokoBIC1dZOqsBeTj8Ij9FIuxNitjsFqDL+DuZwvmGihDa0HIS79MTSVw/f89Ulk3p2M2jbij\n",
|
||||
"TpCkIItiAXbWCZspatvMx2+GoOmu0/Pjqc6iwrXWXyi9/N9Jj+yY/ClUEyj7sTv82Y9nVf++GCrf\n",
|
||||
"1w5ltOrH9rRQKpUQaVxp4gxcgxC4qFFOgMxs83r/WkZSqY9kO/9UmmCqExD/ljnRMUJvxp8FxL1d\n",
|
||||
"H7PGv4WLI5AeltB+MOGIOr9NYMAAAADwAZ4adEN/ACG6NY+qIzQfcYKCb0AhP1JJtQboSZcB2Ux6\n",
|
||||
"0kAZypUjTcd/OmJjJuZBZL4W6I8Qwzms0HJLp8KRrHdk5GfU6sWQ2Z+fhfAzgzC1XgPD4QBqkDkc\n",
|
||||
"T0sPX8iasgf4/DARkJP486Pq1cqH5kOYBwnnR907+n/qb/xaeHwouVk6h00s/qlqepq0S1p/xGR/\n",
|
||||
"GdINVBgCemrU+PPAyI+EQBjfU66sma3ahiVaLQtsD7mxr/vZVvwLqa7Chr1J9NZveiHKnAzIMG16\n",
|
||||
"G9Gmkk/8FUHgdrIbZ2heuBDh1KQSBCztE11k+ocodRJkiMj5AAABBQGeHGpDfwAhujWPq8KUOIXq\n",
|
||||
"Yi8pfsfzwlVQDEG6igccpABq5mcqZlBxZf6f05WsPP5oiGUHFHfSykAR60y9PVPsKziKYov/dHwR\n",
|
||||
"Kft2Arvz4qT56TCewQ06i1++DP3k7arAvxqk9+C83xiDX/XWrTHQ1+jT9fNei76g+LJLvs+Z4UVk\n",
|
||||
"oEaQ3c6fXvOR9+Md7sWQeZnYPXpC/0w6s38iG8bM/+n0jsTdTFeBwE6YfrCAsv/ybSEXYS5eoPM3\n",
|
||||
"f/HRzfWrUb9MZw2WEuoxs0K4qVyNiDTxcyb1DdadbkuzwkaFG7T2ZM6Pebp0YyXRqckmxx6YTGzB\n",
|
||||
"LlKwKmWHeooj6Lm9LlzVgQAAAaFBmh9JqEFsmUwUTBH//rUqgAYrWZggqZs1s6MH6FUT684nhne8\n",
|
||||
"ykZKf89h+0voVegpTcVlgsFoS6xwNTcMDCv9PiwISM3bG5gmdpPxwsd2af4u9VMbVGyE78HSQ5M/\n",
|
||||
"nbkySYm5CPjed6c1fzFNEjUv+hlxYNfv3cPYnGT/Yav/5erFhxatniKB++1xw2wwwm3hwteUjAt3\n",
|
||||
"Bi79ySg16ijYqJM5fa8+vosVJZysXRlnbW7/ITdmkkl3c8ndruo8FzJ7m8m8z0kOYciXI4QIL6Xh\n",
|
||||
"qroOcvOVcWB7Uug78ZH3AowGQXzMbzVMrLD5Q7gJi2vHbYwWBG8EpVzYFtaj2m+v5trtiq/wJKtt\n",
|
||||
"WosqXvVBFnxrWYQFjXg41D/ASyQHPzn2WsqemfWG6/EDepgeax6MAFQfxyDScuq3fNmr8jf0net2\n",
|
||||
"tjnK9AbUeZfaZDCLHpnptMZuk8clMx5Y+UVSA4sRK6q5yL86vVu3TWQ+TGs9ZFdT4m8kNBPSkwSz\n",
|
||||
"rQpsGSml5JPzqe84pJi6yJhqfYRsb2q5mJ8tkrUntJCF8lR106wAAACuAZ4+akN/AB1RsSI82HuA\n",
|
||||
"EDVZr5mUHFl/p/ZTcmoRWj4TfRvTsYw8OlDJB7dvZ/vcXyur4LGUumPqBQUBQHfGq57+bI/8tRzs\n",
|
||||
"Z+nHU7WH8qJ9BM8/NBixjH12m2oVcRb4XvfrX32V+Y0hU+0j88MNPEcdX4rv7aeeep8jA96PadWJ\n",
|
||||
"mSmtmcZfJIFp4fz7nGsOeHvsRUbV0MKDUYmKN+mrh03bThLfJGXI3U9Tnh+UAAABmUGaI0nhClJl\n",
|
||||
"MCCP//61KoAFm+ceSLbmAtKM+jG0tYuAZBSWLg59auQBOS8BoT1gHMsjZkIU234iG6WAeSbLJEu0\n",
|
||||
"KCLhFA+AqaJQGzw142KKgdSAFtORqvq8YepvegTzCCnS1DU11oB/GUVDtDnboQEryLd0x6NUSSMN\n",
|
||||
"cECL9Mzb9QebAeTbVcgtE4xPKr7FEgVH4vbNIioC6rYN5svm+n7fErwoxd1c4B0MbzpTJ9ypWCIt\n",
|
||||
"jDqP/6ecCXKe8Ac6gqcpyPRaKmFcKdx7byHCFs3Y36UHxsmpasB5iKonQtfou1T7ViPEDD+TNshw\n",
|
||||
"6ncI9FQOyx3EYxNs7CdmXQjjuiQ/hVztgan/8HWeS5jp2zgzBv5BXUEnWn+A7+FBONSn2LL/uQ/w\n",
|
||||
"xRZTcRa0x52ow/V5cvgKu7FATp/RCkX/G+w1Qnp+0VyZbVkCutQ1yOnQYxf79Uw65C1zWPQdQMP/\n",
|
||||
"K+VS6vPAs27IKeqUeSeiBKHv/3isIgE+rjxQbN9Lh1YW9R/9r++mSeHrs60NzUtdlXFG/VIZkaKd\n",
|
||||
"XMkAAADXQZ5BRTRMO/8AFlm8HmElw5CLBq61UEezfOfwLuaBDj371pFQE2TaGfrDL2cPvWN1QZqb\n",
|
||||
"tmH36IVd+buOk4nAS7OK6LGtZWekVP+ro0ezqUL6LNjplSKI15AkcuTQweCsbYhrSLoTsRiawYgs\n",
|
||||
"mv975sfbTCY9L8bxROvDNcwG30R1+JWvK+o/hwf/xA32LhBb08HGKIsZFejSCR/ZACyPMiASYPKQ\n",
|
||||
"KnKHiabUDVxwGq+/saT475SIsPn2KAHPd1oy/JYI5la+DZBAp1lqCWQj4yUkciIB5BAAAABzAZ5g\n",
|
||||
"dEN/AB8V9DqLglnogAnlbAbcaeEM/+Dr1d94BLu23/b924ZA1vKLZ+NWO2PdXQ6go3Sf7NA4nwhe\n",
|
||||
"Jfk07l2+PnIu+kI9sd8bYLUmTTByKGfoyEUnQqTPIf5dfjB+AgnVTc5y8pWcKU354gRsJCt4lQAA\n",
|
||||
"AO0BnmJqQ38AHxX0OouCWHEND0XeNAIAEOFUWlDAA6yKdnA6h0XJ5AHh6k3PwK41LuRgTA6dFitc\n",
|
||||
"eGcLOFImUAXmZeNXd8BBiP4Y7WDb/nj/8t7UR/ChuIYJmbMzvyMcttz9Od2nvufuLeTpnnGxlC5D\n",
|
||||
"sKIQ4TiAF1Zf6Jjc46nP71VK4g2t6fmiQijizaslPXbGXByTezIrwT4YraOsiMH4GMwabs58JhIR\n",
|
||||
"tYealSfNunZO0jU9FNwqBbfEknuQIRSATwmWr49+JU7MtkfWDJ9lAsDVu2W/43LTVqxccM6dY8NC\n",
|
||||
"EBnYMhV6U9uYbKYAAAGwQZpnSahBaJlMCCP//rUqgAZTWZgI3NAzNytjReukCJhCqRIQrgVE5TFG\n",
|
||||
"RpO1ZRhoAw39KCX0FTF/pEpCWlYTREK0RX8M+i/Zkz6IOh5zRR0GMJniH0SeRA8U+ZBIRrL9Hl62\n",
|
||||
"8kZwKv6q5Netv/8gTYt8wrrWIwWANbXHJaruY4G39urxvB/yx7ozBV54M/wmK8P5AgF0ljjPQAUZ\n",
|
||||
"DnLEHwmopi3rWM++lGz+7pSmghGU/3PNF3AxzoRutm1cdRdLqAFKdPRrKeDtflDHW39dHMmsizA0\n",
|
||||
"JAD4HEW4vO3o1CbLX2IxlZFPJGuT1QOtzPR7lO7pJCxfeGJXFchlosXXXbYjZoXRMBBKcHqbIWa+\n",
|
||||
"lcjl1FcSEXbk84/WCNR/hEiDPBQ56Zc4Yg/Uu5te5H7B3WBkQkc5+tttienjQao2TkWT/tLarBIb\n",
|
||||
"fSMA+83k8gbv1oyeFIIWqR6ZYarMVbzfFtnH/fWhWkYB/el6Kk3P0OPSTUOVwdEnhQ/ztu0l8Ij9\n",
|
||||
"PRLg28jDAaygyMt+MtthW/hM1h+aETPrMcrgZoJoV2dKCm8mLdDu/CmksDfLJBRBAAABQkGehUUR\n",
|
||||
"LDv/ABi1i6Ag4bMBZUwXqVJnyx2PYc2F7FCjvy82YHTp5//HJrbZhCcYERymRfl1ah1T5z9noaM6\n",
|
||||
"FqCYiKh/nb1NKcv6lay4yu1An9EGWzEXMRaTXWcwehWRMZky6GX2Elv0mAOhcWIk8WVG2FWKKMhd\n",
|
||||
"27a8KH0mx5CnVDu76Igw2moc1+yPfDPZnRGymeVWDMSj1/TY3hGgb5hmSfANHPp4nyrFETtH62Dy\n",
|
||||
"FIZnfZ2tua96PI/858zqXLfYaSaEy66elRjPHGSUQ+kLj7sT6e2TgQoh23asg1dvl0lw6aW2KtOQ\n",
|
||||
"yQVjdxBZzehiTDj2VDDo/FI5LuGH/jfe71B2giPdfSUEN0GwZPmh+oBJ3YPtBDdEXjvqGtPnj9YN\n",
|
||||
"o2RsGDqkSW3oa8BY1cptmQPEHp1SMBrX83w6xtQW5X0AAAD0AZ6kdEN/ACG9oBtcoOCFYVPj9Yn2\n",
|
||||
"v/zfoFr4rWL2j9A7ZlqQHr0ZVpbLuAQJB33EyTSBNnFvVuljxMl3V6GA7Dl0BClPwL31OrTpG1l7\n",
|
||||
"a7ghzL0atyS5ApCJWtp2wOBNzezTQ3N+Y1tH+luIT/i1PP0KLgniqnzZyMrwKfZeXoYEIl7twi0H\n",
|
||||
"PJVeAcAdd8vPtJ2LywfKZ3u1S3on0S/4f7cj446r85qt7SkU/lr6c/+gK5erYXiPq/kf9oXoMNwY\n",
|
||||
"9h0XgCkkY0ibuAMW3BGf/tJy6AGuO11Q5hQVr9nNkIcjB8Plen8B0nqwKQkOaIEp5QYqYQAAAQkB\n",
|
||||
"nqZqQ38AIr2gIG4zhxx5qcIQ9c2Osw5+uNtUP7c8wH627Nk93kOS5kJwZOUsa/GuB8LSJPcgk4rv\n",
|
||||
"NNy4X5Kv65LRXZpkjxKOzss2V4BAkHf3fdjwk53/8IYs8s8oIvwVKvgR9wljv8Ag07Nf+XJo681q\n",
|
||||
"NbSzOUK6bv18ql/byQhgzEpF9gyeKzBYpIes4Jq5ygJqsHenGCQnuZZGCejK/v7YZig/zrXj2vhG\n",
|
||||
"gCib7VW/rlAZYnZRYtYW6jN8+34R58oAelpNik7qpp/KkHdSQspzMHjVSAa9yHgI/KVEUfAeaSTC\n",
|
||||
"N1Z3u1GIF1TdZRU1zNyC6xbuAxPXtz6Ez91WiAF1zBDEIltBAAABt0Gaq0moQWyZTAgj//61KoAG\n",
|
||||
"e1mYdETW3g4OxfplN37UKMHTaFqDxb+9ytAjpKDc3XnMw/MxT04D0MH+PToJ4KWEuN7AocErZRv2\n",
|
||||
"Rz2GQBbpS8lS31542pk6xM8YYh0/yeF1AnMnBxO2+HilOPhojFg3EW0klIcf/AybMYAo9NSuBD9C\n",
|
||||
"s4e75EU0t8atdvYkg/yfik+FMNyFYTUg/mi4EKL8VgLWVSi8mxQ1+/EWE53/+fwb7K+j+527pMW9\n",
|
||||
"VCj1B/8oEXG8oxyHRw/TQGPoBS7lGz9zLwh8gXusGZBvY9Xy0pnRdJKDkZLO/YjZFLNiCRPsHTqL\n",
|
||||
"i2GYmJ9itG9pRnevDN9cAKQP0fgHBe/nvlXFVK7JMen+RKub1gCuPtFfO/y6rA2fstwepz1bap4Z\n",
|
||||
"wJXzTLHNbeZ6/jnjul1UTQDo+Wyv2+WNy23qAxLYAQV2nquSCySITwJSTVvg+SdePIAmj5UPClGF\n",
|
||||
"OrJIf0RX1xfSrhrpF0W0EhW8ceypgG4+dXb+bPwXKBwbO3GymyW89X2WJwubd13etWWTwju8K204\n",
|
||||
"+w8LWTwxqMyJaP52mExMi4W5Yjr9AyAAAAElQZ7JRRUsO/8AGBMzRwWmJT1LkI6j4SrVrkQOkpGr\n",
|
||||
"7qmVB6agtU/P7NMI3vz5LIs62lee9zlMDhLgStRXRkKeHaPAGaY9hwFwZg4RZnlEijsKiC6r+GA3\n",
|
||||
"jOJMGPR2G+iEvFq9JqYdk0b1d9ABTX/7oiMKav8zTfVNhhkqe32oj6u1ioYXU2U/9Y4cH3f/N9Gx\n",
|
||||
"JhjbFALTGuJMdeB2a/pmxPSRSx2DhwUwXe3BT4iK5IJF2QdQUjRydlTK56i3AOElSAfT6NVqnLr8\n",
|
||||
"mfbO/AiWtC7ZCdSKqLQrBheoCisxuwRDc+0Qj4IlPLBawyneGpiLaece3KMzpKTos+5YxlSYlKtg\n",
|
||||
"/Me6PG+fH2sUI9B09T2Px/9ucFTXTUC5j4ELLv01D5MY2VAAAADfAZ7odEN/ACK9oCBuM4G+uLMW\n",
|
||||
"L2dP1lfTvDhmlpluM7IE4yEUJKicqu4KM5OijIBGmwd/fv/FYUE8C16mNefQ0Uy/D+0+Hpx1ZFAP\n",
|
||||
"3vl+5XYGW/hV3tVz6fpDmClx2VYPTKI+QsHyxc+qQa6raGV2rQAFnERDWDAoPELDpD0DBzrtQ9Gj\n",
|
||||
"f1X0zbjtJNpqrwp/hRbaIrr15pQNp8wHXKVl3vyz9d+FD2rUtkJQVzj6V7XpNVWdz4mpDYH1JRGS\n",
|
||||
"i2MURr0RotwXgP3Qnz/8L/EyxM0Sb/CNWw8xQFPmbCgpDwAAAOUBnupqQ38AH77opN4Quy1TZxAA\n",
|
||||
"Og5d0nOlbRa1c67qPfhIW7P+8Av3GtFE0HFQCvcwO1xKybwlnguY0Nqo5bzwqVZ4m1UebapfH7JG\n",
|
||||
"d9M94gSTzLBzp+7XrhnquJ9dwfh5fBCyLWBt8xSfTcJZr1HXGrAMOw+Jv+pCMMogCsMVlWbHeQuT\n",
|
||||
"mD3/yuQp5lDob+9AYNdyDEIT/fV+2vxg/LuQxTIX08ne1pWMu28zMsHEcHxols+2LTEYzIWCi8BU\n",
|
||||
"K3ZtJRE3rAjZxLOQ4w3m2m/D157HitClmlKcP9jJchoyWV95Jy2gAAABu0Ga70moQWyZTAgj//61\n",
|
||||
"KoAGg+KazAhO48Rk+mELCfGa3jedcL7j4i4wMKqReszSNQj5h17BpSVMT9hX+zPhBrSs6Vj7HyaE\n",
|
||||
"qm6lvw7kPbwwNhW67XEllpB7/AB7Dtmc/Lsrl2N4BzMZzIFVEJCqVkWDwHz0DCyLCsDRdQx8uGEg\n",
|
||||
"Ikolt9wM9AgzvQ7TxR98jTrIYP8SP9CCVhDDASOwwiUKcH0pWRrgAYwjw8Gf7OlbogYj/no1BpFx\n",
|
||||
"lYglvem+TH822s9SIsjJ3EA1IN/sTGSWgAXqwMREDl6rGx1E4un7krghrGWUm+/7j4jDoGqrYrQI\n",
|
||||
"g7E+ktnqOLNELPNyQd8WQ/umSuXC1xL1umwA8X5+yPqMMHEIeQL1fzz/JWAXyMH93QMSzGumbhKw\n",
|
||||
"Zwg0U+25Tvu4PnK5VQHbV0zvOU2Pj+MGf/nsDxqxrqZsD9S4YY9rcTfMxz/MkkzIgfRGQF/OgLHr\n",
|
||||
"joIjF7P6XCeWe+XUgCwqZQG68PRNzfXkn+zUJpMMk0jjnoYnDkQ975Dz0Z65i4o7OdZtwLEOfaoE\n",
|
||||
"pB0fo5td4PyA9vYIFlRo3xi7uvrQcih7/M7KbZFgAAAA9kGfDUUVLDv/ABlUeHLsmGHl+OQZEho1\n",
|
||||
"hMDtEgrgr/N3AttUVM/7crMT5dwlm5uvzGVCn6w/p670sqgr5PJ6oiWC1npINQXp4CRzsctCmXzn\n",
|
||||
"Ugai5K7NbwfaQcfbZKrjzT/10H2u4nhhcuuZyNqUHfbG94mETU3kKDy9A89Il0BA9I1A+R3yjNfc\n",
|
||||
"+Nz5BwP3DN+ZYjka/GHLl0y68JgPyPoe9w8jyG5IXdu2vCa+LYvH9kU234z4psgT4qxlrdkhxxyP\n",
|
||||
"UJXN8nPpx6cXDiQznv0L2owqy0csZbCzUw4CVJ98G+4T1R39bjI9WT0YHLigorskW6Eh4QAAAMYB\n",
|
||||
"nyx0Q38AHP005LNxTWEpiZ1J9di26t3EruDGda0AVBouFN0G1ywEJMXJZuIMxrfHCac7PtwdnQsN\n",
|
||||
"5ABPxruKApfvrd4v1WFO3Cl2Zd1SOG3/r1ORn6HwtueiSFcG0RNU2EL7iLFK3PfYpxwH299J2sER\n",
|
||||
"9fENVpZ0Q3jjs6HsM0edV/QB07Ofn+R5vOS4TYLqhcaZAnuosw5RlS5g1Q8CuW9BZXMHWP4TGLry\n",
|
||||
"nY5Y9ez3m8FrqVUEclyyvuywjGI3odTE+j8AAABPAZ8uakN/ABxO1hwOVZzdUT/5uAde8ARk7MlV\n",
|
||||
"yrPTe1pSPQMTQCpdw5z/lBFmnGZwxWyqh+3IqkDkhpoxeW8ZCVdNB2x/1RnvvpDhcO3MwQAAAbZB\n",
|
||||
"mzNJqEFsmUwII//+tSqABlJow5npTNmtYD16z8AGI7v0s/GnfyqOWOrIj7MzWLMA+5yFNFLu1hTu\n",
|
||||
"dlbGlkD8jL3ONezhs0gurnHp2pFLsP3djo3BgKHcLr5q4kg5WMX28rT11jnIH4bHAuJDI0/Gub5+\n",
|
||||
"542H8l9OurnbLu7ccDaau7k+AVcLYmIJfjhEaissSRpn2usY/14Z8WeJwbzUwclx5b0pufbMDj2m\n",
|
||||
"E4jonmtfVQvsVKXSLVBGus9F0XUey7wsw1/Hxpa1Dj6X89JFMTZZDEgLc8SXNlb52uC+3SYuA3pO\n",
|
||||
"yIZ3zYRDkwb5/sIpC9s/jtT+DR4JrFHAg/zOLQvdBHh2BZ/H88Qk1FOi1nkBwtogVwTsAvTRwaaM\n",
|
||||
"L+Fy6Vw65xxtt2p06IrGo+vGB6Ev7rBsQ1lA5dJTwIES1/HSnI96cCqyJNRkq8io7XoKHq1jP8jJ\n",
|
||||
"K8KCILcbnjTzWMILhY3EuZ8pRzEGblkg+ofcWDech+PkwDbk4flJvQ1eVGNBBbzkH58MbHNkp5C1\n",
|
||||
"pRDfsnIb9VIwGZIgexRK5GP0EM8ZveKhcNpqg0C7EdFVGM7dDkwAAAFMQZ9RRRUsO/8AFKVU3AQX\n",
|
||||
"TKYCKlUskM896ABcbpuBaq23+VbIBAleYM+Uh2fmC8hKxXufvA+Jyd8ERfcMKq2QBuOeaw8cG8nv\n",
|
||||
"l00dW9FnZ2ewlISmCmZ99L0bw0GXPORXq89pSQ+5zLmGTJWLpbqXg/Gg/k26eFQ7yctp0OrjpANw\n",
|
||||
"gpKfTmSwqfpdIyAO4i1HmWAczC/dxtyvK6EJns7ev/M+uhg/UBsLPdCc4ktjYaoFvgpYJl8v+SaB\n",
|
||||
"iW6/qJFs8B7ABY+Xoa/3pJdDPx7Wo16RIr9F0VKx7gY2CroKhVZyesK3QK039pTJworswqeMoYtQ\n",
|
||||
"SxUGWdIlnZAh/LxAqJSAgdbCea7vV7Jw7UJ3RZWLCaN03DO0g6FTEO0PNlB/y2w2d5hCS2yZtMLR\n",
|
||||
"726poAjDu+5lgVHjodzIR1vHcKS57NpFhydymmBuCPgAAAD3AZ9wdEN/AB0F8+qoYAk/JkWPAABe\n",
|
||||
"eS/K4R2z8W8rEZ4Es2dHO2B1xqZeWERk/2j9D35SD32hnizfkl5AQkKu7sKMRtxB0qUTg/5Ai8ci\n",
|
||||
"ewPsEvh0cTnE+UnVVZQsy2FhpSkguxSgj2GzhV7H4B4oQdASRatW+4ge9XWWDwbNzKDfs2ikSZGn\n",
|
||||
"ZK2J2cdk5ZNdF/NbhHS0c6vDp3S53pob/1OoP8UOX13YMuZJYtnSstfaINj9HWvrLOMusuMgy0ge\n",
|
||||
"hr00WpqM4G4LNFMeeHMWs3VdDioqjp1BlI0pyKTUMl2eH+Urm0ENGx6u7gM90gDkOBdN7tgm4QAA\n",
|
||||
"ASkBn3JqQ38AIb2gdaiDraOcmiaTmEkpG5LCwpD7mwoBhbYx9hK/huA/Rlz76MMOi96iXfBz3DSh\n",
|
||||
"vG5XYVehGnggzBAkHfGgYDsO5F3SWLpvAiWuQYgw379rpdMwhqWoBgIHHe7UqoU3PiKCUX8CUwon\n",
|
||||
"PUuq8JY4AYYztu7mmGelokJyoAJS97RU/X6H+RdsNNzitkC1d8I6jDPIy7qqN4tCnL3rY6Yesfv1\n",
|
||||
"e8kTaN9S190RCoZyxCFd2JzsfgZhniY0nZmfUb/Ilr3HhSfAoNjT9YPJpZU0gCEN/XEjzBiwlPnv\n",
|
||||
"oPqWZP16sXNdepP+5XR/WuewqnrAjpV8x4yn9rFVK/AamriL1xzzEUk66pD3JF3R2TNlp/oPgGf2\n",
|
||||
"3Zht7rWDs3F41xpI2UAAAAHmQZt3SahBbJlMCCP//rUqgAZb4qaPpZ+oz92BGWQCHjdoHQHQRvdw\n",
|
||||
"JuWMeCAf9SCNq3pRzo+QLWwm+zJnwkwndhEvWHQ/SujctvY5pe+lS1QEjQXzeizSF8k6tO14eAtl\n",
|
||||
"F+Mync2FH/YIAKwBXgDqn6AXOHpWQcynHtaJryxWYm270/11pJpJLJP1UcyORiPI54DPlbzdu+l/\n",
|
||||
"jiFd4hpdaoZTSIPUh6A6ClqPxEqekFrNjAxud2WiOSd4IE7Kaf//vpwZ0mh9bmck4Z3rAu3/6Cvy\n",
|
||||
"KA3WyoqAFX4UT0ZjH4z6LrUYRBEZElMEZc4snCHRyZf+tjKnoDXWOrVFpzxu69dV7GJ+V1irRKox\n",
|
||||
"Pd1LRXYUoYi+P14fumR2pYbtX+VBW+m+c7NAd8Z01d3TTKV7Mg7nTZdtCA/oFcETl7++5b2EIheP\n",
|
||||
"k2Fg+5ToPyynpqzSsvv9vWMyfYTJnDg6PojbFsxSs0nRUvqnP5QCdr6QHBhWXFOG60F0RsLzEsNc\n",
|
||||
"wpNcPfKeYjjdCfe8YUIVjq0PBSvcnC+B/ETQWaX7IFbWhPaknWILlx3KsiYwYSMVn5rwfQd4Jkdd\n",
|
||||
"9H+fdht5f/EJHYCK5IGupAjPxHpu+QiB/iUSmCHkkTiMqsG8twzlljjsl22n8veAAAABCEGflUUV\n",
|
||||
"LDv/ABi1iDy3ZgloGXmZPsuhVsylb+qqNi7GSIfQ+OHuoRwObuWCiDJsleSNbQz9VgmS3f493Q1l\n",
|
||||
"fk0LSjQ0QBKQCe3UmCkV8vYYHcKN9CZn1L0i/3IstLHQcy91VMXucG0IQjYMvd5K4nw1TsRQ+zNt\n",
|
||||
"c33OM7wT4gTiFbFnfUP6sORkbyxKD8+9VWHRCKkGnoAnjqhwkHV3YzaNKz290rB0XwxFDvsi8iqf\n",
|
||||
"z+DNrf49LxpvDCniJY8b921MDAhjoaXQisEELwuIkEG2MG16iA+xn4KZIc8cifkUnLKYTAHTEosc\n",
|
||||
"/geFGHZmG9d/0Ad4ehB1+UFj3eeT8gc12jWX2ySdSQAAAUIBn7R0Q38AHbXz6qhgDdTYSzAi1h3K\n",
|
||||
"16Xr3JTVUajJdHP4n1zwK/61yxZ9pP4QSRtJbkJZWH6vivN5vckWYfjVoaQoNcq3qWx+bI+OTtrh\n",
|
||||
"UNznJnNVmMngQpK+748FuR69zyCunCVVntkmuIrtQvOCVbqBuRz5Qxvz7t49H+VL6IAp+Rh2gf74\n",
|
||||
"0j/UPUfosZ/ElbvCMu7rvOP7cWI+JN6KUOE+/AXQCyHGSkSvvSc5FsX0fFal2fQXaEkH67EHfCc5\n",
|
||||
"xhdseiByl+PiqAs8A9zuy4qmXDeeIj+3Yojnw30fZXbmjymzKitBenCylofDP0QjYedpgwNVFWxv\n",
|
||||
"pKDrpf57i5C5JHBxrkMOZNs3TkoKjfQLvKDT/j1Fvw02tHitRU1MR1mnPja0zhtM0e5b68dpKMZ6\n",
|
||||
"9AO+761c+Ba/40Js4HhAAAABBwGftmpDfwAiuitca5eBLMHeP5uZuF9cX0/VXhqHcuiBABGdnZlB\n",
|
||||
"vvbdh+1A3f4uQyVZizhw70/9zDh2nx3tQGn11M/7g3e0ETDcFJMpuy3pyqZj8OhCsFXcJg/Dg2Ky\n",
|
||||
"wNn+F0Nd65xqPmrT4IAWVNyWgNuyHhWrg80hH2qe3n3QFTH+AG0t1LUQWRwdt8cDbAi+8IGZZrTn\n",
|
||||
"QzKAGB5g+jkMrZS2t5af/14Dikh/TUO9x6vp3udUZwfEqX9x43nyKd2KkcrjEt0VxTQ1LHt4TKTU\n",
|
||||
"ov9g2wymXIrIg/m2cGScMEoY8xa4E2v0IBu8Siv364Oh7cF3cjWG+ZJkZ6xGCUsmpmsJt4n9AAAB\n",
|
||||
"cEGbu0moQWyZTAgj//61KoAGC/pGgJ9CubE/Hy/U90CEEMEEbF2P5yKT5EQsPLolJYuDn1q5ANTN\n",
|
||||
"SJwpmVcvZVK2Tco4v2Comd7hwZPuuXhX+lvh+l6ZtjrC3czf1ZVbdumb3r3D/ioYe7qcFNf7aS5r\n",
|
||||
"2YnlPFx/ox3Po4uR9L227Pa5JPu/JVHojzbyIvC2hUPLYoK3yo8EFTOEx9VW2Kka/dDqBAClQEXM\n",
|
||||
"coaHOVrqvWOBlx0SmrR2Fn5qD0ttjA+wKyG9Ww/+/fxdGsIy8lThxbGnpYEDoqIDxAPPdyC1j/7C\n",
|
||||
"x1S6SZ6cX8TWD+edELbCVScHr4twowGayNRkN1sGJ3ChzFZqefnm592USWq1KVPalCkn+IgAbkI0\n",
|
||||
"gf8crEnxuQcz5L3ov1loEzryk4ptgt40vN/cUUrwi49uNdXDzDlba6ntBbOYIPKYQqVbRsWX//V3\n",
|
||||
"7VjjZzb0fU2VitbTbNlERmPP5obsCvIRmiOfAAAA7EGf2UUVLDv/ABgUnqVjfMZzbMROTbEr98Ov\n",
|
||||
"G6hTv8LwbEOVBTuoZFwTL9eOUuW51yt7Pk5XoOwvCITHjPxM0+ACPLC5p8LXGPLXOMFwxyKNAOm2\n",
|
||||
"+bVnL7eC/eonqWYHV7ElnGiaPE4DZvhksvIAUMvT1hgYsLWg5pHxPTMEf4vPc7k/U4gx+qn0dLIb\n",
|
||||
"xLE6WPqhOli4SJOCHhekKlwgxlnM6S8wIxjTrZQVP6tyjUXc7nRDpn5+4xHTB5JTQd/Y+v5uYYim\n",
|
||||
"vSxL9Lp9+sJa/YqUqQ0UFcQR3Tlp/PCrTJ5gUcQmlTDSjEV8pdpwAAABAgGf+HRDfwAhujWPq7Ze\n",
|
||||
"gCJPvLBRhSSbcG6El3BFXKqbl3V6+XLJCsWmxwO7Xskzh85D3/GGBbxCjXU3okqTeEYfyjkOl+SH\n",
|
||||
"4VGFs6uGeBXI6FuyUdCktochZVIQW+D6bukSQtQ9xBoZWqRH4hlWFBiT6bV+GQGerlgKyeaNsqD5\n",
|
||||
"s+IDfM/wce0dikHUV0++Nr2rHe3jcRRrSy2FHjFSMdnyldmaj1iFauYYGv6d3l/8LPJtc5g5u4Q0\n",
|
||||
"WerxF6DQAN+WlQUAod5dWuqnUKOySujKDQh4Sh1bNoaribkhCngsbjiJUpnyDzJfWcRyF47YB87L\n",
|
||||
"Omkfy8ijCTvweGsJYAgScQAAAQUBn/pqQ38AIbo1j6vClDiF6mIvKX7IDWIXdy1QyeJm7hwAhKrN\n",
|
||||
"5ZQTH6lrtJ9D3xtslHyvy2ywnd5a5/owLJHRc2EtkPadJ8Uji+G9O7CT6ooBM3rAgAWaKgWADHof\n",
|
||||
"Rk55HzZ+V8DMw4S4pnRLudTRFnX1DyLXHV3VXMnhAeP+ewFDtdkUHGMhcSI0U8KajX0wWNdBGeGb\n",
|
||||
"D8Ns9BH8mxfhSu/SqyYkA2AIdaTRVyL0w7XOVFH3DXljVqrcwMdXPvGgiBcw6chMaLbepo7nSmh1\n",
|
||||
"vAbwAQYruBhNTN0eawky0jofbme4HocI40c1sz31wjy2n2/uelK4XikXYFYmVtl4Kdutz8YAAAGb\n",
|
||||
"QZv9SahBbJlMFEwR//61KoAGK1mYIKmbNbOjB+hVE+vOJ4Z3vMpGSn/PYftL6FXoKU3FZYLBaEus\n",
|
||||
"cDU8hX8r/T4sCEjN2tKC+to/+IoDOzT/F3qpjao2Qnfg6SHJn87cmSTE3IR8bzvTmr+Ye4Ac/+hl\n",
|
||||
"xYNmjmRG01XaPV08JLNnbV2zuL5cn/7CsR7I4pKAadGKE6UheVLfqn0i791ThTaaO2OCRjsSWF8e\n",
|
||||
"1o7SXLcWHdmh1WCFSlfjet1S/FkIphxf8M1ZQjLPF96/W7wlOpiP6jEis8o6251YpmdqxS3VSmv/\n",
|
||||
"s9Bv3ISLvkMspiZj+iQwr28MINay/7syEY2A7ZiKqNUJX069yti8CuYwd1gGvQZSlufV+auVaTNU\n",
|
||||
"xocXs0XuFW0e/AWENf2i3yxrLFTHW9CCBeoKH21CafAHq6hi+H/e9DkZU77nSidgvmP6DIx/XjI4\n",
|
||||
"Sp9anaBxYwcylzQtEH2XN+nrwpDPp45KYG9LI0xieadJ2QOTHIvADfNhP/PY2gqE0NQ2qkvQc0a7\n",
|
||||
"Xw6JCi5LfZz745MNAAAA8QGeHGpDfwAdo0DVwAgarNdw1dyEo22Z+2voCmn3MepWOJpNH9uE22Fc\n",
|
||||
"UAf4fo25DS3VGYdH0kZ3bYGxdzd+R7awrh1yiW2ItRU9+fbZ+7eJ43X/1GQK2tLeuYX+rXNnNYVn\n",
|
||||
"3JiyKGKiuk48G4gEpBGTo6LBxeBZg0OXhUHfR3yB3h9X56ir+g4EbNusZoLNQh23BaGzc9/s1PO9\n",
|
||||
"1PPSEqrUiAosSTAygJNCJGqMs5yCqcS+EZopY3ntHhRp/rTMQhL4aAxAb8XQkEJtEmWrzD4p1eX6\n",
|
||||
"QEZh/6hTVX/Gz191R2H/Dtkpg79J3GkssFm0vPkAAAH+QZoBSeEKUmUwII///rUqgAWb5x2D2a6r\n",
|
||||
"t0Z9OpYFG2tABdnWLgsFoSkhKeOGdpZQLTxZJNtdR1o3VEUaCsJe7TDcWLiNBjbFk4iCHCNTwP1B\n",
|
||||
"ET8aIdy/mqBaPrTdtuT/6FMRex7yXV0X/b0t3IdDKZDeFLpQzjHVkdbvbm3BNwCciVQUNcJ7Sjbw\n",
|
||||
"T4hbhPp0oEDMMYqhG0FXqi8cqsDNhwZenV4L974lIjS1k1BRVCVuxIwrhHZ+ZNeKQOVccqtyU7fb\n",
|
||||
"1nmmkdbnAEav9V5tnQTxoYHQvrZLL4f7C+LE0IOtSnKggNbex2Xp0FNi9T/+fjTgmF5bW9OJ+WCx\n",
|
||||
"leyLvNiQF8k0bwSPMh7702+7OB9yXypsT0VFN+3fNlolLg4yJ7ye2ijeDcs0TyR0KI9OqHHwk9VT\n",
|
||||
"lv0R4DjKMuNtxv3yyDdQ02ld84rRe/IbVoqtujoBlwArv27SRkTybmrwQddynU1vfFNgJ2tkTxsX\n",
|
||||
"EuhAyTUDk1pdyrePvO3Kyjq07E+ZdqW1unVDCL0p2PAM0Bdj+ozOm4QJPGRq3YEQjJpnk1BNx6E0\n",
|
||||
"yZMxRvkyW2tYZosgoDR8rW5jEN/sH3PsICgk/jLYhgpsvFfXxjf0NPxMCt81bgYfKxBAoUrGuF/8\n",
|
||||
"Gb453zLMx96NgDfHj/3/yVULmADuEWX3e7X8vwCYAAAA+kGeP0U0TDv/ABZZvB5hJcOQiwautVBH\n",
|
||||
"s3zn8C7pn/fvWkU93yxomewKAdw+9VXghKzj8nMy4EQ6n26QhvOvN3ZOGl4wrl9GlrTzWwgssqXz\n",
|
||||
"oLBd9XVA4LrC7D/kDb3CEAYvcHCWxuhsk3WHFeLlRhwB95RghbDR4boSp+CQz3CY9L8bxC9Ohf/r\n",
|
||||
"dy9+xoLX1H7kyaZJ3YehTdM+5Wu6Hpc4XocPo/ogFns0WlfgVPekkiZdh228q3p+OFEAyCsprsbc\n",
|
||||
"bh4x6zwYau0C11ECccZga0PS18ku4j08dAfMYirHksImmVD9Aw8yto6D9YLwntF8IaA+FPG9VagA\n",
|
||||
"AAB+AZ5edEN/AB8T3aVQEVcYwT0kXXzzDP4yP2lC7bONTcb6acU9HQ87UdrkSLI4+OHKFlU0EAFz\n",
|
||||
"P/GPhcZ5NOIVfnz6vsVd3DH3XZLg43PF1cMypwOcG8sbzfthjMA4FQSgVvJe40X2MhECJet9t2G/\n",
|
||||
"XdWa+YBzkUuLdbRPBeGTAAABAAGeQGpDfwAfDtYNiNYeWLJ1JGi8AHLac8oZrJR5tDRFy80bn36g\n",
|
||||
"01RfxVuWBDFeUQUU4VHoswV2zHbq6MzAloc0SM3f88f/qXApn5tj32GTO8MmdjG+5h2BlZLr7lVk\n",
|
||||
"BcTdEueULRCVgGF4dFB9PX4Y3jYyGQfKH/BWnAEfbs4hEQ8ebrGB8mSRpcKz5q1oNG7pkp8qNfsq\n",
|
||||
"nkhG1h5qVJ826dklpNvhQDQdQnVi0zusZWH7g9GItx1/0euTzo8U/z7D4DrbASMUmgB0DC8TSqJd\n",
|
||||
"xZ+UMAYbubxMdW+iPv2N1tIKXHdcOVBHhDDt1MeY4rBQavQwdjpFZBiUMt5ya+AAAAH0QZpFSahB\n",
|
||||
"aJlMCCH//qpVAAyk+dgiPwCMdRFSufgoxGSIR+/0rSe9Cp9hy8WpEfkfjpu1RSHWd3zlulcFC+Nh\n",
|
||||
"XPR//hjTft5KlTxkfWUjrzSX8Q8sCzZTRHqzVvb/rscPsXHQf0E6taB/yJXWDm9ZR5fbjX3mwQRc\n",
|
||||
"72p/7Nk/lJUO//4LM1qLgtlckFFvGA4aviZYHpBb9w1OJg/Jqwvkkixar7ua0LNG3ane8+4yu/5g\n",
|
||||
"n8krsqxREhrpsaI39b317zkKj6KVaeKiNvQ1KBsts5QsX+yTO1tzmbv5PRxGS8tz2hKf4zB8fbWM\n",
|
||||
"XhqB6Gi2mMVEo6jXnv5vErjT3e551EcovqLpcSnuFBTI4jT6V7ZqZq5zqsmn23ZqFTbBnXJfy5qg\n",
|
||||
"Xc1RIbUSG7SAPcicWIbuNtZ4GQS+WKAEZUxr++6VPQD3gpW4BeKCxEy910wCA11VXaqCgcSgS5FA\n",
|
||||
"dwACIPfrp0NhEyPCvA4qNFC9NitDM1I8HthEGAjfRL6imFuJfW4+Sk08ZcO8JNBK0/bkkNG7XFo7\n",
|
||||
"Hs15nZek/o+FGsRiwki6FYqc1HBc8skTelrrFiYgicL9M/ehriAlP3GGSVQdD58oSyTAbR/XOwHh\n",
|
||||
"/k7736bu5rnUg2SpAi/FdrWUFq0zx+C7UUDgbK+SgABs/nsA2PEAAAE5QZ5jRREsO/8AGLWLoCDh\n",
|
||||
"swFlTBepUmfLHY9h6nZJebQZXCAk5QrW0LEqJOc6Tf3RfmBa+BH+trXpxDsoWsYBGGxFB6vHNSw7\n",
|
||||
"QTuHxSINvJ7kINONdsnA7unyZfe+/dUQpBab4cd9DfyyBJrHeEf61R0Nfn0RkLu3bt6BWIYQlYtM\n",
|
||||
"K9Nfs/vIPwJSfjpXcON5DPtNNDffXZk4RydlgN+S/E7EUmDtA6DaeTT9v6cz5zUd9DSGZ32drbmv\n",
|
||||
"ejyP/MmN69TJZPy1fo/BndGgtSNNbFsKVeTDjxqdcz9cfjIrJ3P86/aSSTu++gY85cN7L+QFkn5k\n",
|
||||
"/lX20+90kKxSs6X+x+u7me+jslyG1ZQaBGKwDx+RwViiPwDARZocg2yGxzRByDsEM59E93SHlUl9\n",
|
||||
"GT+PqBiUfn848MoVbAAAAOoBnoJ0Q38AIb2gG1yg4IVhU+P1ifa//N+gWvitYvaP0DtmWpAevRlW\n",
|
||||
"lsu4BAkHfcTJNIE2cW9WUS6DTP7xEfhthE/Au7/XTkrYH5bPnHuWMD+L4E2Ys7TDv/WnXsb8WMjs\n",
|
||||
"GVKLefmxcZqtW10iMABVusPZiYCVoxR1g16JAWeZ7iIjTKxZ0g1yWUY7SYbSh6LLTrvWvhE7lU5U\n",
|
||||
"CdpswEmIpPdhoFfYojayY1ypJuWbbU1PB5nvwD9t85tVUeFQcQm5aN4kQawNooLXHpvRUW63Gqd8\n",
|
||||
"iY0WiZheEXu2JHmP8XM7t/dfyrk3Fx0AAAEdAZ6EakN/ACK9oCBuM4cceanCEPXT9ZV29ukUDhUK\n",
|
||||
"Q43qY97tIKPQ4ZLk+xSOgxBfQxL7yIrZscfkKmKCSoYxQfZ+tSzvOZ1GhW2ifFuVzAIEg7+77ixc\n",
|
||||
"Kx//CGLPLPJ464HVUHGkhcx37PQ+kbQrXlUbN3cWUp0Qf4LtEibFhZ+LpSZJ4udEDKi6Q/S18Psl\n",
|
||||
"/qmdcccWROb1W4f/Xy9V+lMS0Du/XhxzsIhWccm/rlAZXG9J5NMLdRfS734QHwqLqFpe0KPTU/Mz\n",
|
||||
"iY1ev2MPDzHxs95uiDK6gRc1gvD7TgXhVki57ReTigwP0Vcnsm9mMNHj3Nt6/RMhlMwCLQhy6qqL\n",
|
||||
"YC7Z58RnNbEutfWZAa9Y2SYIcplB+x/e/c7TAAABlkGaiUmoQWyZTAh3//6plgAykehDX8oAigHL\n",
|
||||
"uS7e5BiYpAhLP0Zp72qQ9WFfih2hD6ViubvwAAAy+5vuYY1yi1tJuPfBi/DL0xvClymIwqUp5EK2\n",
|
||||
"pijOf291KPaqRN5kbJjB/2wfKr1+XMiKLX6DysREeFfQlDwQLBucvt+vNOXQokOSOb4yTYfCyIZ/\n",
|
||||
"GHqmX89FI8GoC7SVJ8dqrGOCOpcjHfvSY2QsrqBh9dhAV5Sl9v/BQKeopbgb9Qoepn/uEMh2fyEW\n",
|
||||
"JmX+JgRFJalJclAgIlVBNaF+FoinY0YPKhqMcuoH+rtaEk2LTWu4NHdn9ysTAkHlBR2G+58hU289\n",
|
||||
"8X49s9CJy7d2oeKmsapTwnIxxJ2LNCm+TxMniHit0ZHqI5VMxQ+5ZJ2tPHM7/cT3gdae3yVR8+YM\n",
|
||||
"/KU5H6oISvxSd8TybIcXMyYVHn6O+gwy4SKx3AkMYLFpKRIO1eI3ZmEPll+L/2Ahp3aDBQxulIlY\n",
|
||||
"Qc1v4+BSAHSjYxY/VpZwrkFkWgmuXijX9pnceU+eCQb0BkKKYYEAAAE7QZ6nRRUsO/8AGBMrmVrk\n",
|
||||
"4p6lyIABr6JcUvWGXYV0DKg9NQWqfn9mmEaqxk2L7hAoVLAefd2AT3uOnaK6MhbcdSJ0jbOgAdky\n",
|
||||
"1NCtoTFYEK1L3oNAlJW78V3WE6NttmJ67HTQFhc7jbPt6n2fAdknrF4tehh2ttPPRj0ZMNDck2O/\n",
|
||||
"Og/0bAxzaaL7DSYz/qGCfH6ue/8E9mejEEqzP8HffVv8Obhn2u8eQxOotWj4hO+DblITeYVYJXny\n",
|
||||
"h4Mo9PoOPQCtWY4pEEbVZmokYfc6NrhoTMJC8d+WVfQUp/9dQN2FtoGBhQPHEwvVbIcYhR7B4iO2\n",
|
||||
"lHuM7fr8Nz2PLRQOuR4Lhle59+tgw9IpLSJGfVu5u0NIKILKM/viNoDYYuKxIDdR/J6apnFKAoah\n",
|
||||
"uk9v6if+0v3ru/qsdmBBAAAA3wGexnRDfwAivaAgbjOBvrizFgvOGL0/6vuXDLpZruFaiDwd2rdX\n",
|
||||
"jHVzx9p+aFelpPZGVUpD9afD05Q+ygH73y/cGcCL/wq72iJds0hr5PUpNV/aSoB5zpjnS1krIC0g\n",
|
||||
"xgvcsTNLJd1aFsq1w5umkQK05c9QgDPa1eUOrMmn+/YlpdytXE6u+4FAjIpYVgn74StUYfcT8IT8\n",
|
||||
"SGX5Wru0UB/4BiwZwXDYz0r2pPySvTt1TUg57ubb0S/BqMvEVZ5rArNFw0GaRO5EmmTuHjFK31Ed\n",
|
||||
"ZcrudMiOWUSCfSesj44AAADfAZ7IakN/AB++6KTdC0Gg2vR2G3QAHQcu6TnStota0MGq57eEms8e\n",
|
||||
"GSZ8YTYymFLgl7YZGG1YXmh3orKEBl6b97W6tU9/+wsf9/cg00EpDLAMwmuhlqrl+tcaP161PaCT\n",
|
||||
"db1JjfLZ6rQlIR/u8Lq+hDMPBrZgZ6lFmsHEDUzmL1vhrC/Eg5wjH+dLR3xJpn70Bg13IMQhP99X\n",
|
||||
"7a/GD8u5DFMhlFEykeU8M0AF5LVwxauGljyJ2PG9wt/W7GNjLNgsX4aFTR897+cKWdUMsr13pC8x\n",
|
||||
"KjWMpGHXcQ2lKSkGzAAAAR5BmspJqEFsmUwIb//+p4QAYn+ayCPJyJ7QOf/irXuB3I7yUvrv3Wd8\n",
|
||||
"OLQaJBb/+EMR1r6SAeh0um3VtQPrwYoZU0zDlMzZlECRYSRYOAqgamI/sUVWVEYaYAVab8QpucQ/\n",
|
||||
"sSTh0wVtYsFYYkt/gr7uhkEpx1NPSuJ9CqWeDhMsefol+oaGZkPTooDGiCB29X8Zubhk7s13xY5c\n",
|
||||
"l2KWl6cdQs8QOBu4PKBLJa04v3ctO+FHUCNJTXN7J5YnaOHn+BLPFy7A6HoUxVmuK9kB/hB9j6ln\n",
|
||||
"0nykP3r6vgXJiVxtga3Ek+Zj3edZUHSAUux6bbxkCgdvPWLgxmKM0iIQ0SZS+9McjsqW/5Kw1hL5\n",
|
||||
"sobdDT0GsHJ+I+IDODn9/vmRAAAGqm1vb3YAAABsbXZoZAAAAAAAAAAAAAAAAAAAA+gAAB1MAAEA\n",
|
||||
"AAEAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAA\n",
|
||||
"AAAAAAAAAAAAAAAAAAAAAAAAAAIAAAXUdHJhawAAAFx0a2hkAAAAAwAAAAAAAAAAAAAAAQAAAAAA\n",
|
||||
"AB1MAAAAAAAAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAQAAAAAGw\n",
|
||||
"AAABIAAAAAAAJGVkdHMAAAAcZWxzdAAAAAAAAAABAAAdTAAACAAAAQAAAAAFTG1kaWEAAAAgbWRo\n",
|
||||
"ZAAAAAAAAAAAAAAAAAAAKAAAASwAVcQAAAAAAC1oZGxyAAAAAAAAAAB2aWRlAAAAAAAAAAAAAAAA\n",
|
||||
"VmlkZW9IYW5kbGVyAAAABPdtaW5mAAAAFHZtaGQAAAABAAAAAAAAAAAAAAAkZGluZgAAABxkcmVm\n",
|
||||
"AAAAAAAAAAEAAAAMdXJsIAAAAAEAAAS3c3RibAAAALNzdHNkAAAAAAAAAAEAAACjYXZjMQAAAAAA\n",
|
||||
"AAABAAAAAAAAAAAAAAAAAAAAAAGwASAASAAAAEgAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAA\n",
|
||||
"AAAAAAAAAAAAAAAAABj//wAAADFhdmNDAWQAFf/hABhnZAAVrNlBsJaEAAADAAQAAAMAUDxYtlgB\n",
|
||||
"AAZo6+PLIsAAAAAcdXVpZGtoQPJfJE/FujmlG88DI/MAAAAAAAAAGHN0dHMAAAAAAAAAAQAAAEsA\n",
|
||||
"AAQAAAAAFHN0c3MAAAAAAAAAAQAAAAEAAAJgY3R0cwAAAAAAAABKAAAAAQAACAAAAAABAAAUAAAA\n",
|
||||
"AAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABAAAAAAAgAABAAAAAABAAAMAAAAAAEAAAQAAAAA\n",
|
||||
"AQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAAB\n",
|
||||
"AAAEAAAAAAEAABQAAAAAAQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEA\n",
|
||||
"AAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAAAwAAAAAAQAA\n",
|
||||
"BAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABQAAAAAAQAACAAAAAABAAAA\n",
|
||||
"AAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgA\n",
|
||||
"AAAAAQAAAAAAAAABAAAEAAAAAAEAABQAAAAAAQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAA\n",
|
||||
"AAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAA\n",
|
||||
"AAEAAAwAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABQAAAAA\n",
|
||||
"AQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAAB\n",
|
||||
"AAAIAAAAABxzdHNjAAAAAAAAAAEAAAABAAAASwAAAAEAAAFAc3RzegAAAAAAAAAAAAAASwAABs8A\n",
|
||||
"AAI/AAABMQAAAGEAAAD8AAABkwAAAMcAAAEbAAABNgAAALkAAAGdAAAA/QAAAMYAAADdAAABzAAA\n",
|
||||
"AQEAAADcAAAARQAAAdsAAAE9AAAA0QAAAU4AAAIZAAABBwAAAV4AAAEDAAABXgAAAPMAAAD0AAAB\n",
|
||||
"CQAAAaUAAACyAAABnQAAANsAAAB3AAAA8QAAAbQAAAFGAAAA+AAAAQ0AAAG7AAABKQAAAOMAAADp\n",
|
||||
"AAABvwAAAPoAAADKAAAAUwAAAboAAAFQAAAA+wAAAS0AAAHqAAABDAAAAUYAAAELAAABdAAAAPAA\n",
|
||||
"AAEGAAABCQAAAZ8AAAD1AAACAgAAAP4AAACCAAABBAAAAfgAAAE9AAAA7gAAASEAAAGaAAABPwAA\n",
|
||||
"AOMAAADjAAABIgAAABRzdGNvAAAAAAAAAAEAAAAsAAAAYnVkdGEAAABabWV0YQAAAAAAAAAhaGRs\n",
|
||||
"cgAAAAAAAAAAbWRpcmFwcGwAAAAAAAAAAAAAAAAtaWxzdAAAACWpdG9vAAAAHWRhdGEAAAABAAAA\n",
|
||||
"AExhdmY1Ny44My4xMDA=\n",
|
||||
"\"\u003e\n",
|
||||
" Your browser does not support the video tag.\n",
|
||||
"\u003c/video\u003e"
|
||||
],
|
||||
"text/plain": [
|
||||
"\u003cIPython.core.display.HTML at 0x7f1286b190b8\u003e"
|
||||
]
|
||||
},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import time\n",
|
||||
"import traceback\n",
|
||||
"import sys\n",
|
||||
"\n",
|
||||
"from matplotlib import pyplot as plt\n",
|
||||
"from matplotlib import animation as anim\n",
|
||||
"import numpy as np\n",
|
||||
"from IPython import display\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@tf.autograph.experimental.do_not_convert\n",
|
||||
"def render(boards):\n",
|
||||
" fig = plt.figure()\n",
|
||||
"\n",
|
||||
" ims = []\n",
|
||||
" for b in boards:\n",
|
||||
" im = plt.imshow(b, interpolation='none')\n",
|
||||
" im.axes.get_xaxis().set_visible(False)\n",
|
||||
" im.axes.get_yaxis().set_visible(False)\n",
|
||||
" ims.append([im])\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" ani = anim.ArtistAnimation(\n",
|
||||
" fig, ims, interval=100, blit=True, repeat_delay=5000)\n",
|
||||
" plt.close()\n",
|
||||
"\n",
|
||||
" display.display(display.HTML(ani.to_html5_video()))\n",
|
||||
" except RuntimeError:\n",
|
||||
" print('Coult not render animation:')\n",
|
||||
" traceback.print_exc()\n",
|
||||
" return 1\n",
|
||||
" return 0\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def gol_episode(board):\n",
|
||||
" new_board = tf.TensorArray(tf.int32, 0, dynamic_size=True)\n",
|
||||
"\n",
|
||||
" for i in tf.range(len(board)):\n",
|
||||
" for j in tf.range(len(board[i])):\n",
|
||||
" num_neighbors = tf.reduce_sum(\n",
|
||||
" board[tf.maximum(i-1, 0):tf.minimum(i+2, len(board)),\n",
|
||||
" tf.maximum(j-1, 0):tf.minimum(j+2, len(board[i]))]\n",
|
||||
" ) - board[i][j]\n",
|
||||
" \n",
|
||||
" if num_neighbors == 2:\n",
|
||||
" new_cell = board[i][j]\n",
|
||||
" elif num_neighbors == 3:\n",
|
||||
" new_cell = 1\n",
|
||||
" else:\n",
|
||||
" new_cell = 0\n",
|
||||
" \n",
|
||||
" new_board.append(new_cell)\n",
|
||||
" final_board = new_board.stack()\n",
|
||||
" final_board = tf.reshape(final_board, board.shape)\n",
|
||||
" return final_board\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"@tf.function(experimental_autograph_options=(\n",
|
||||
" tf.autograph.experimental.Feature.EQUALITY_OPERATORS,\n",
|
||||
" tf.autograph.experimental.Feature.BUILTIN_FUNCTIONS,\n",
|
||||
" tf.autograph.experimental.Feature.LISTS,\n",
|
||||
" ))\n",
|
||||
"def gol(initial_board):\n",
|
||||
" board = initial_board\n",
|
||||
" boards = tf.TensorArray(tf.int32, size=0, dynamic_size=True)\n",
|
||||
"\n",
|
||||
" i = 0\n",
|
||||
" for i in tf.range(NUM_STEPS):\n",
|
||||
" board = gol_episode(board)\n",
|
||||
" boards.append(board)\n",
|
||||
" boards = boards.stack()\n",
|
||||
" tf.py_function(render, (boards,), (tf.int64,))\n",
|
||||
" return i\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"# Gosper glider gun\n",
|
||||
"# Adapted from http://www.cplusplus.com/forum/lounge/75168/\n",
|
||||
"_ = 0\n",
|
||||
"initial_board = tf.constant((\n",
|
||||
" ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
|
||||
" ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
|
||||
" ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,1,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
|
||||
" ( _,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_,_,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_ ),\n",
|
||||
" ( _,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,1,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_ ),\n",
|
||||
" ( _,1,1,_,_,_,_,_,_,_,_,1,_,_,_,_,_,1,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
|
||||
" ( _,1,1,_,_,_,_,_,_,_,_,1,_,_,_,1,_,1,1,_,_,_,_,1,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
|
||||
" ( _,_,_,_,_,_,_,_,_,_,_,1,_,_,_,_,_,1,_,_,_,_,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
|
||||
" ( _,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
|
||||
" ( _,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
|
||||
" ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
|
||||
" ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
|
||||
"))\n",
|
||||
"initial_board = tf.pad(initial_board, ((0, 10), (0, 5)))\n",
|
||||
"\n",
|
||||
"_ = gol(initial_board)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "7NgrSPCZxs3h"
|
||||
},
|
||||
"source": [
|
||||
"#### Generated code"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "hIGYeX0Cxs3i"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(tf.autograph.to_code(gol.python_function))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"last_runtime": {
|
||||
"build_target": "",
|
||||
"kind": "local"
|
||||
},
|
||||
"name": "Simple algorithms using AutoGraph",
|
||||
"provenance": [
|
||||
{
|
||||
"file_id": "19q8KdVF8Cb_fDd13i-WDOG_6n_QGNW5-",
|
||||
"timestamp": 1528465909719
|
||||
}
|
||||
],
|
||||
"version": "0.3.2"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -1,53 +0,0 @@
|
||||
# Description: Batching scheduling library.
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"py_test",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "batch_py",
|
||||
srcs = glob(["python/ops/*.py"]) + ["__init__.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:batch_ops",
|
||||
"//tensorflow/python:batch_ops_gen",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "batch_ops_kernels",
|
||||
deps = [
|
||||
"//tensorflow/core:batch_ops_op_lib",
|
||||
"//tensorflow/core/kernels:batch_kernels",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "batch_ops_test",
|
||||
size = "small",
|
||||
srcs = ["python/ops/batch_ops_test.py"],
|
||||
python_version = "PY2",
|
||||
shard_count = 5,
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"manual",
|
||||
"no_pip",
|
||||
"nomac",
|
||||
],
|
||||
deps = [
|
||||
":batch_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:script_ops",
|
||||
],
|
||||
)
|
@ -1,27 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Ops and modules related to batch.
|
||||
|
||||
@@batch_function_v1
|
||||
@@batch_function
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.batching.python.ops.batch_ops import batch_function
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
remove_undocumented(__name__)
|
@ -1,120 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Operations for automatic batching and unbatching."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import gen_batch_ops
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.ops.batch_ops import batch
|
||||
from tensorflow.python.ops.batch_ops import batch_function
|
||||
from tensorflow.python.ops.batch_ops import unbatch
|
||||
# pylint: enable=unused-import
|
||||
|
||||
|
||||
@ops.RegisterGradient("Batch")
|
||||
def _BatchGrad(op, *out_grads): # pylint: disable=invalid-name
|
||||
"""Gradient for batch op."""
|
||||
gradients = []
|
||||
for i in range(len(op.inputs)):
|
||||
gradients.append(
|
||||
gen_batch_ops.unbatch(
|
||||
out_grads[i],
|
||||
op.outputs[-2],
|
||||
op.outputs[-1],
|
||||
timeout_micros=op.get_attr("grad_timeout_micros"),
|
||||
shared_name="batch_gradient_{}_{}".format(op.name, i)))
|
||||
return gradients
|
||||
|
||||
|
||||
@ops.RegisterGradient("Unbatch")
|
||||
def _UnbatchGrad(op, grad): # pylint: disable=invalid-name
|
||||
return [
|
||||
gen_batch_ops.unbatch_grad(
|
||||
op.inputs[0],
|
||||
op.inputs[1],
|
||||
grad,
|
||||
op.inputs[2],
|
||||
shared_name="unbatch_gradient_{}".format(op.name)), None, None
|
||||
]
|
||||
|
||||
|
||||
def batch_function_v1(num_batch_threads,
|
||||
max_batch_size,
|
||||
batch_timeout_micros,
|
||||
allowed_batch_sizes=None,
|
||||
grad_timeout_micros=60 * 1000 * 1000,
|
||||
unbatch_timeout_micros=60 * 1000 * 1000,
|
||||
max_enqueued_batches=10):
|
||||
"""Batches the computation done by the decorated function.
|
||||
|
||||
This is the older version of batch_function(). Please use batch_function()
instead.
|
||||
|
||||
Args:
|
||||
num_batch_threads: Number of scheduling threads for processing batches
|
||||
of work. Determines the number of batches processed in parallel.
|
||||
max_batch_size: Batch sizes will never be bigger than this.
|
||||
batch_timeout_micros: Maximum number of microseconds to wait before
|
||||
outputting an incomplete batch.
|
||||
allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,
|
||||
does nothing. Otherwise, supplies a list of batch sizes, causing the op
|
||||
to pad batches up to one of those sizes. The entries must increase
|
||||
monotonically, and the final entry must equal max_batch_size.
|
||||
grad_timeout_micros: The timeout to use for the gradient. See the
|
||||
documentation of the unbatch op for more details. Defaults to 60s.
|
||||
unbatch_timeout_micros: The timeout to use for unbatching. See the
|
||||
documentation of the unbatch op for more details. Defaults to 60s.
|
||||
max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
The decorated function will return the unbatched computation output Tensors.
|
||||
"""
|
||||
def decorator(f): # pylint: disable=missing-docstring
|
||||
def decorated(*args):
|
||||
with ops.name_scope("batch") as name:
|
||||
for a in args:
|
||||
if not isinstance(a, ops.Tensor):
|
||||
raise ValueError("All arguments to functions decorated with "
|
||||
"`batch_function` are supposed to be Tensors; "
|
||||
"found %s" % repr(a))
|
||||
batched_tensors, batch_index, id_t = gen_batch_ops.batch(
|
||||
args,
|
||||
num_batch_threads=num_batch_threads,
|
||||
max_batch_size=max_batch_size,
|
||||
batch_timeout_micros=batch_timeout_micros,
|
||||
max_enqueued_batches=max_enqueued_batches,
|
||||
allowed_batch_sizes=allowed_batch_sizes,
|
||||
grad_timeout_micros=grad_timeout_micros,
|
||||
shared_name=name)
|
||||
outputs = f(*batched_tensors)
|
||||
if isinstance(outputs, ops.Tensor):
|
||||
outputs_list = [outputs]
|
||||
else:
|
||||
outputs_list = outputs
|
||||
with ops.name_scope("unbatch") as unbatch_name:
|
||||
unbatched = [
|
||||
gen_batch_ops.unbatch(t, batch_index, id_t,
|
||||
timeout_micros=unbatch_timeout_micros,
|
||||
shared_name=unbatch_name + "/" + t.name)
|
||||
for t in outputs_list]
|
||||
if isinstance(outputs, ops.Tensor):
|
||||
return unbatched[0]
|
||||
return unbatched
|
||||
return decorated
|
||||
return decorator
|
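For reference, a minimal usage sketch of the `batch_function_v1` decorator defined above, assuming TF 1.x graph mode; the decorated function, tensor shapes, and parameter values are illustrative only.

```python
import tensorflow as tf
from tensorflow.contrib.batching.python.ops import batch_ops

@batch_ops.batch_function_v1(
    num_batch_threads=1, max_batch_size=10, batch_timeout_micros=100000)
def times_two(in_t):
  # Runs on a whole batch assembled from concurrent callers; results are
  # unbatched back to each caller.
  return in_t * 2

inp = tf.placeholder(dtype=tf.int32, shape=[1])
out = times_two(inp)
with tf.Session() as sess:
  print(sess.run(out, feed_dict={inp: [3]}))  # => [6]
```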
@ -1,87 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for the currently experimental in-graph batch ops."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
from tensorflow.contrib.batching.python.ops import batch_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def delayed_plus1(x):
|
||||
"""Sleeps for 100ms then returns x+1."""
|
||||
time.sleep(0.1)
|
||||
return x + 1
|
||||
|
||||
|
||||
class BatchOpsTest(test.TestCase):
|
||||
"""Tests for batch_ops.{un,}batch."""
|
||||
|
||||
def testBasicUnbatchV1Decorated(self):
|
||||
"""Tests that the batch_function_v1 decorator works."""
|
||||
with self.cached_session() as sess:
|
||||
|
||||
@batch_ops.batch_function_v1(1, 10, 100000)
|
||||
def computation(in_t):
|
||||
return in_t + 1
|
||||
|
||||
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
|
||||
result = computation(inp)
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([result], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [3])
|
||||
|
||||
def testUnbatchGrad(self):
|
||||
"""Tests that batch and unbatch are differentiable."""
|
||||
with self.cached_session() as sess:
|
||||
inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
|
||||
batched, index, id_t = batch_ops.batch(
|
||||
[inp], num_batch_threads=1, max_batch_size=2,
|
||||
batch_timeout_micros=36000000, grad_timeout_micros=1000000,
|
||||
batching_queue="")
|
||||
computation = batched[0] * batched[0]
|
||||
result = batch_ops.unbatch(computation, index, id_t,
|
||||
timeout_micros=1000000, shared_name="unbatch")
|
||||
grad = gradients_impl.gradients(result, inp)
|
||||
thread_results = []
|
||||
|
||||
def worker():
|
||||
thread_results.extend(sess.run([grad], feed_dict={inp: [1]}))
|
||||
|
||||
worker_thread = threading.Thread(target=worker)
|
||||
worker_thread.start()
|
||||
main_results = sess.run([grad], feed_dict={inp: [2]})
|
||||
worker_thread.join()
|
||||
self.assertEqual(thread_results[0], [2])
|
||||
self.assertEqual(main_results[0], [4])
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -1,60 +0,0 @@
|
||||
# Description:
|
||||
# Contains ops for working with statistical distributions,
|
||||
# particularly useful for Bayesian inference.
|
||||
# APIs here are meant to evolve over time.
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//learning/brain/contrib/bayesflow:__subpackages__",
|
||||
"//tensorflow:__subpackages__",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_library(
|
||||
name = "bayesflow_py",
|
||||
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/distributions:distributions_py",
|
||||
"//tensorflow/contrib/framework:framework_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:functional_ops",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nn",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:util",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "monte_carlo_test",
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/monte_carlo_test.py"],
|
||||
additional_deps = [
|
||||
":bayesflow_py",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/contrib/distributions:distributions_py",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/python/ops/distributions",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:random_seed",
|
||||
],
|
||||
)
|
@ -1,17 +0,0 @@
|
||||
# Notice
|
||||
|
||||
`tf.contrib.bayesflow` has moved!
|
||||
|
||||
See new code at [github.com/tensorflow/probability](
|
||||
https://github.com/tensorflow/probability).
|
||||
|
||||
Switch imports with:
|
||||
|
||||
```python
|
||||
# old
|
||||
import tensorflow as tf
|
||||
tfp = tf.contrib.bayesflow
|
||||
|
||||
# new
|
||||
import tensorflow_probability as tfp
|
||||
```
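For code that used this package's `monte_carlo` helpers, the closest TensorFlow Probability equivalent is `tfp.monte_carlo.expectation`; the sketch below is illustrative and the exact module layout may differ across TFP versions.

```python
import tensorflow_probability as tfp

p = tfp.distributions.Normal(loc=0., scale=1.)
# Roughly equivalent to tf.contrib.bayesflow.monte_carlo.expectation(...).
approx_mean = tfp.monte_carlo.expectation(
    f=lambda x: x,
    samples=p.sample(10000, seed=42),
    log_prob=p.log_prob)
```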
|
@ -1,36 +0,0 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Ops for representing Bayesian computation.
|
||||
|
||||
Use [tfp](/probability/api_docs/python/tfp) instead.
|
||||
|
||||
## This package provides classes for Bayesian computation with TensorFlow.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,line-too-long
|
||||
from tensorflow.contrib.bayesflow.python.ops import monte_carlo
|
||||
# pylint: enable=unused-import,line-too-long
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
|
||||
_allowed_symbols = [
|
||||
'monte_carlo',
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
@ -1,19 +0,0 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""ops module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
@ -1,260 +0,0 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for Monte Carlo Ops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib import layers as layers_lib
|
||||
from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo_lib
|
||||
from tensorflow.contrib.bayesflow.python.ops.monte_carlo_impl import _get_samples
|
||||
from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.distributions import distribution as distribution_lib
|
||||
from tensorflow.python.ops.distributions import kullback_leibler
|
||||
from tensorflow.python.ops.distributions import normal as normal_lib
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
layers = layers_lib
|
||||
mc = monte_carlo_lib
|
||||
|
||||
|
||||
class ExpectationImportanceSampleTest(test.TestCase):
|
||||
|
||||
def test_normal_integral_mean_and_var_correctly_estimated(self):
|
||||
n = int(1e6)
|
||||
with self.cached_session():
|
||||
mu_p = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64)
|
||||
mu_q = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
|
||||
sigma_p = constant_op.constant([0.5, 0.5], dtype=dtypes.float64)
|
||||
sigma_q = constant_op.constant([1.0, 1.0], dtype=dtypes.float64)
|
||||
p = normal_lib.Normal(loc=mu_p, scale=sigma_p)
|
||||
q = normal_lib.Normal(loc=mu_q, scale=sigma_q)
|
||||
|
||||
# Compute E_p[X].
|
||||
e_x = mc.expectation_importance_sampler(
|
||||
f=lambda x: x, log_p=p.log_prob, sampling_dist_q=q, n=n, seed=42)
|
||||
|
||||
# Compute E_p[X^2].
|
||||
e_x2 = mc.expectation_importance_sampler(
|
||||
f=math_ops.square, log_p=p.log_prob, sampling_dist_q=q, n=n, seed=42)
|
||||
|
||||
stddev = math_ops.sqrt(e_x2 - math_ops.square(e_x))
|
||||
|
||||
# Relative tolerance (rtol) chosen 2 times as large as minimum needed to
# pass.
|
||||
# Convergence of mean is +- 0.003 if n = 100M
|
||||
# Convergence of stddev is +- 0.00001 if n = 100M
|
||||
self.assertEqual(p.batch_shape, e_x.get_shape())
|
||||
self.assertAllClose(p.mean().eval(), e_x.eval(), rtol=0.01)
|
||||
self.assertAllClose(p.stddev().eval(), stddev.eval(), rtol=0.02)
|
||||
|
||||
def test_multivariate_normal_prob_positive_product_of_components(self):
|
||||
# Test that importance sampling can correctly estimate the probability that
|
||||
# the product of components in a MultivariateNormal are > 0.
|
||||
n = 1000
|
||||
with self.cached_session():
|
||||
p = mvn_diag_lib.MultivariateNormalDiag(
|
||||
loc=[0.], scale_diag=[1.0, 1.0])
|
||||
q = mvn_diag_lib.MultivariateNormalDiag(
|
||||
loc=[0.5], scale_diag=[3., 3.])
|
||||
|
||||
# Compute E_p[X_1 * X_2 > 0], with X_i the ith component of X ~ p(x).
|
||||
# Should equal 1/2 because p is a spherical Gaussian centered at (0, 0).
|
||||
def indicator(x):
|
||||
x1_times_x2 = math_ops.reduce_prod(x, axis=[-1])
|
||||
return 0.5 * (math_ops.sign(x1_times_x2) + 1.0)
|
||||
|
||||
prob = mc.expectation_importance_sampler(
|
||||
f=indicator, log_p=p.log_prob, sampling_dist_q=q, n=n, seed=42)
|
||||
|
||||
# Relative tolerance (rtol) chosen 2 times as large as minimum needed to
# pass.
|
||||
# Convergence is +- 0.004 if n = 100k.
|
||||
self.assertEqual(p.batch_shape, prob.get_shape())
|
||||
self.assertAllClose(0.5, prob.eval(), rtol=0.05)
|
||||
|
||||
|
||||
class ExpectationImportanceSampleLogspaceTest(test.TestCase):
|
||||
|
||||
def test_normal_distribution_second_moment_estimated_correctly(self):
|
||||
# Test the importance sampled estimate against an analytical result.
|
||||
n = int(1e6)
|
||||
with self.cached_session():
|
||||
mu_p = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
|
||||
mu_q = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64)
|
||||
sigma_p = constant_op.constant([1.0, 2 / 3.], dtype=dtypes.float64)
|
||||
sigma_q = constant_op.constant([1.0, 1.0], dtype=dtypes.float64)
|
||||
p = normal_lib.Normal(loc=mu_p, scale=sigma_p)
|
||||
q = normal_lib.Normal(loc=mu_q, scale=sigma_q)
|
||||
|
||||
# Compute E_p[X^2].
|
||||
# Should equal [1, (2/3)^2]
|
||||
log_e_x2 = mc.expectation_importance_sampler_logspace(
|
||||
log_f=lambda x: math_ops.log(math_ops.square(x)),
|
||||
log_p=p.log_prob,
|
||||
sampling_dist_q=q,
|
||||
n=n,
|
||||
seed=42)
|
||||
e_x2 = math_ops.exp(log_e_x2)
|
||||
|
||||
# Relative tolerance (rtol) chosen 2 times as large as minimum needed to
# pass.
|
||||
self.assertEqual(p.batch_shape, e_x2.get_shape())
|
||||
self.assertAllClose([1., (2 / 3.)**2], e_x2.eval(), rtol=0.02)
|
||||
|
||||
|
||||
class GetSamplesTest(test.TestCase):
|
||||
"""Test the private method 'get_samples'."""
|
||||
|
||||
def test_raises_if_both_z_and_n_are_none(self):
|
||||
with self.cached_session():
|
||||
dist = normal_lib.Normal(loc=0., scale=1.)
|
||||
z = None
|
||||
n = None
|
||||
seed = None
|
||||
with self.assertRaisesRegexp(ValueError, 'exactly one'):
|
||||
_get_samples(dist, z, n, seed)
|
||||
|
||||
def test_raises_if_both_z_and_n_are_not_none(self):
|
||||
with self.cached_session():
|
||||
dist = normal_lib.Normal(loc=0., scale=1.)
|
||||
z = dist.sample(seed=42)
|
||||
n = 1
|
||||
seed = None
|
||||
with self.assertRaisesRegexp(ValueError, 'exactly one'):
|
||||
_get_samples(dist, z, n, seed)
|
||||
|
||||
def test_returns_n_samples_if_n_provided(self):
|
||||
with self.cached_session():
|
||||
dist = normal_lib.Normal(loc=0., scale=1.)
|
||||
z = None
|
||||
n = 10
|
||||
seed = None
|
||||
z = _get_samples(dist, z, n, seed)
|
||||
self.assertEqual((10,), z.get_shape())
|
||||
|
||||
def test_returns_z_if_z_provided(self):
|
||||
with self.cached_session():
|
||||
dist = normal_lib.Normal(loc=0., scale=1.)
|
||||
z = dist.sample(10, seed=42)
|
||||
n = None
|
||||
seed = None
|
||||
z = _get_samples(dist, z, n, seed)
|
||||
self.assertEqual((10,), z.get_shape())
|
||||
|
||||
|
||||
class ExpectationTest(test.TestCase):
|
||||
|
||||
def test_works_correctly(self):
|
||||
with self.cached_session() as sess:
|
||||
x = constant_op.constant([-1e6, -100, -10, -1, 1, 10, 100, 1e6])
|
||||
p = normal_lib.Normal(loc=x, scale=1.)
|
||||
|
||||
# We use the prefix "efx" to mean "E_p[f(X)]".
|
||||
f = lambda u: u
|
||||
efx_true = x
|
||||
samples = p.sample(int(1e5), seed=1)
|
||||
efx_reparam = mc.expectation(f, samples, p.log_prob)
|
||||
efx_score = mc.expectation(f, samples, p.log_prob,
|
||||
use_reparametrization=False)
|
||||
|
||||
[
|
||||
efx_true_,
|
||||
efx_reparam_,
|
||||
efx_score_,
|
||||
efx_true_grad_,
|
||||
efx_reparam_grad_,
|
||||
efx_score_grad_,
|
||||
] = sess.run([
|
||||
efx_true,
|
||||
efx_reparam,
|
||||
efx_score,
|
||||
gradients_impl.gradients(efx_true, x)[0],
|
||||
gradients_impl.gradients(efx_reparam, x)[0],
|
||||
gradients_impl.gradients(efx_score, x)[0],
|
||||
])
|
||||
|
||||
self.assertAllEqual(np.ones_like(efx_true_grad_), efx_true_grad_)
|
||||
|
||||
self.assertAllClose(efx_true_, efx_reparam_, rtol=0.005, atol=0.)
|
||||
self.assertAllClose(efx_true_, efx_score_, rtol=0.005, atol=0.)
|
||||
|
||||
self.assertAllEqual(np.ones_like(efx_true_grad_, dtype=np.bool),
|
||||
np.isfinite(efx_reparam_grad_))
|
||||
self.assertAllEqual(np.ones_like(efx_true_grad_, dtype=np.bool),
|
||||
np.isfinite(efx_score_grad_))
|
||||
|
||||
self.assertAllClose(efx_true_grad_, efx_reparam_grad_,
|
||||
rtol=0.03, atol=0.)
|
||||
# Variance is too high to be meaningful, so we'll only check those which
|
||||
# converge.
|
||||
self.assertAllClose(efx_true_grad_[2:-2],
|
||||
efx_score_grad_[2:-2],
|
||||
rtol=0.05, atol=0.)
|
||||
|
||||
def test_docstring_example_normal(self):
|
||||
with self.cached_session() as sess:
|
||||
num_draws = int(1e5)
|
||||
mu_p = constant_op.constant(0.)
|
||||
mu_q = constant_op.constant(1.)
|
||||
p = normal_lib.Normal(loc=mu_p, scale=1.)
|
||||
q = normal_lib.Normal(loc=mu_q, scale=2.)
|
||||
exact_kl_normal_normal = kullback_leibler.kl_divergence(p, q)
|
||||
approx_kl_normal_normal = monte_carlo_lib.expectation(
|
||||
f=lambda x: p.log_prob(x) - q.log_prob(x),
|
||||
samples=p.sample(num_draws, seed=42),
|
||||
log_prob=p.log_prob,
|
||||
use_reparametrization=(p.reparameterization_type
|
||||
== distribution_lib.FULLY_REPARAMETERIZED))
|
||||
[exact_kl_normal_normal_, approx_kl_normal_normal_] = sess.run([
|
||||
exact_kl_normal_normal, approx_kl_normal_normal])
|
||||
self.assertEqual(
|
||||
True,
|
||||
p.reparameterization_type == distribution_lib.FULLY_REPARAMETERIZED)
|
||||
self.assertAllClose(exact_kl_normal_normal_, approx_kl_normal_normal_,
|
||||
rtol=0.01, atol=0.)
|
||||
|
||||
# Compare gradients. (Not present in `docstring`.)
|
||||
gradp = lambda fp: gradients_impl.gradients(fp, mu_p)[0]
|
||||
gradq = lambda fq: gradients_impl.gradients(fq, mu_q)[0]
|
||||
[
|
||||
gradp_exact_kl_normal_normal_,
|
||||
gradq_exact_kl_normal_normal_,
|
||||
gradp_approx_kl_normal_normal_,
|
||||
gradq_approx_kl_normal_normal_,
|
||||
] = sess.run([
|
||||
gradp(exact_kl_normal_normal),
|
||||
gradq(exact_kl_normal_normal),
|
||||
gradp(approx_kl_normal_normal),
|
||||
gradq(approx_kl_normal_normal),
|
||||
])
|
||||
self.assertAllClose(gradp_exact_kl_normal_normal_,
|
||||
gradp_approx_kl_normal_normal_,
|
||||
rtol=0.01, atol=0.)
|
||||
self.assertAllClose(gradq_exact_kl_normal_normal_,
|
||||
gradq_approx_kl_normal_normal_,
|
||||
rtol=0.01, atol=0.)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -1,36 +0,0 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Monte Carlo integration and helpers.
|
||||
|
||||
Use [tfp.monte_carlo](/probability/api_docs/python/tfp/monte_carlo) instead.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# go/tf-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.bayesflow.python.ops.monte_carlo_impl import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
_allowed_symbols = [
|
||||
'expectation',
|
||||
'expectation_importance_sampler',
|
||||
'expectation_importance_sampler_logspace',
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
@ -1,374 +0,0 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Monte Carlo integration and helpers.
|
||||
|
||||
@@expectation
|
||||
@@expectation_importance_sampler
|
||||
@@expectation_importance_sampler_logspace
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.util import deprecation
|
||||
|
||||
__all__ = [
|
||||
'expectation',
|
||||
'expectation_importance_sampler',
|
||||
'expectation_importance_sampler_logspace',
|
||||
]
|
||||
|
||||
|
||||
def expectation_importance_sampler(f,
|
||||
log_p,
|
||||
sampling_dist_q,
|
||||
z=None,
|
||||
n=None,
|
||||
seed=None,
|
||||
name='expectation_importance_sampler'):
|
||||
r"""Monte Carlo estimate of \\(E_p[f(Z)] = E_q[f(Z) p(Z) / q(Z)]\\).
|
||||
|
||||
With \\(p(z) := exp^{log_p(z)}\\), this `Op` returns
|
||||
|
||||
\\(n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ], z_i ~ q,\\)
|
||||
\\(\approx E_q[ f(Z) p(Z) / q(Z) ]\\)
|
||||
\\(= E_p[f(Z)]\\)
|
||||
|
||||
This integral is done in log-space with max-subtraction to better handle the
|
||||
often extreme values that `f(z) p(z) / q(z)` can take on.
|
||||
|
||||
If `f >= 0`, it is up to 2x more efficient to exponentiate the result of
|
||||
`expectation_importance_sampler_logspace` applied to `Log[f]`.
|
||||
|
||||
User supplies either a `Tensor` of samples `z`, or the number of samples to draw, `n`.
|
||||
|
||||
Args:
|
||||
f: Callable mapping samples from `sampling_dist_q` to `Tensors` with shape
|
||||
broadcastable to `q.batch_shape`.
|
||||
For example, `f` works "just like" `q.log_prob`.
|
||||
log_p: Callable mapping samples from `sampling_dist_q` to `Tensors` with
|
||||
shape broadcastable to `q.batch_shape`.
|
||||
For example, `log_p` works "just like" `sampling_dist_q.log_prob`.
|
||||
sampling_dist_q: The sampling distribution.
|
||||
`tfp.distributions.Distribution`.
|
||||
`float64` `dtype` recommended.
|
||||
`log_p` and `q` should be supported on the same set.
|
||||
z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
|
||||
n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
|
||||
seed: Python integer to seed the random number generator.
|
||||
name: A name to give this `Op`.
|
||||
|
||||
Returns:
|
||||
The importance sampling estimate. `Tensor` with `shape` equal
|
||||
to batch shape of `q`, and `dtype` = `q.dtype`.
|
||||
"""
|
||||
q = sampling_dist_q
|
||||
with ops.name_scope(name, values=[z, n]):
|
||||
z = _get_samples(q, z, n, seed)
|
||||
|
||||
log_p_z = log_p(z)
|
||||
q_log_prob_z = q.log_prob(z)
|
||||
|
||||
def _importance_sampler_positive_f(log_f_z):
|
||||
# Same as expectation_importance_sampler_logspace, but using Tensors
|
||||
# rather than samples and functions. Allows us to sample once.
|
||||
log_values = log_f_z + log_p_z - q_log_prob_z
|
||||
return _logspace_mean(log_values)
|
||||
|
||||
# With \\(f_{plus}(z) = max(0, f(z)), f_{minus}(z) = max(0, -f(z))\\),
|
||||
# \\(E_p[f(Z)] = E_p[f_{plus}(Z)] - E_p[f_{minus}(Z)]\\)
|
||||
# \\( = E_p[f_{plus}(Z) + 1] - E_p[f_{minus}(Z) + 1]\\)
|
||||
# Without incurring bias, 1 is added to each to prevent zeros in logspace.
|
||||
# The logarithm is approximately linear around 1 + epsilon, so this is good
|
||||
# for small values of 'z' as well.
|
||||
f_z = f(z)
|
||||
log_f_plus_z = math_ops.log(nn.relu(f_z) + 1.)
|
||||
log_f_minus_z = math_ops.log(nn.relu(-1. * f_z) + 1.)
|
||||
|
||||
log_f_plus_integral = _importance_sampler_positive_f(log_f_plus_z)
|
||||
log_f_minus_integral = _importance_sampler_positive_f(log_f_minus_z)
|
||||
|
||||
return math_ops.exp(log_f_plus_integral) - math_ops.exp(log_f_minus_integral)
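To make the identity above concrete, here is a small self-contained check of E_p[f(Z)] = E_q[f(Z) p(Z) / q(Z)] in NumPy/SciPy rather than the TensorFlow ops above; the distributions, f, and sample size are arbitrary illustrative choices.

```python
import numpy as np
from scipy import stats

# p = N(0, 1), q = N(0, 2), f(z) = z**2; the exact value of E_p[Z^2] is 1.
rng = np.random.RandomState(42)
z = stats.norm(0., 2.).rvs(size=200000, random_state=rng)        # samples from q
weights = stats.norm(0., 1.).pdf(z) / stats.norm(0., 2.).pdf(z)  # p(z) / q(z)
print(np.mean(z ** 2 * weights))  # close to 1.0
```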
|
||||
|
||||
|
||||
def expectation_importance_sampler_logspace(
|
||||
log_f,
|
||||
log_p,
|
||||
sampling_dist_q,
|
||||
z=None,
|
||||
n=None,
|
||||
seed=None,
|
||||
name='expectation_importance_sampler_logspace'):
|
||||
r"""Importance sampling with a positive function, in log-space.
|
||||
|
||||
With \\(p(z) := exp^{log_p(z)}\\), and \\(f(z) = exp{log_f(z)}\\),
|
||||
this `Op` returns
|
||||
|
||||
\\(Log[ n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ] ], z_i ~ q,\\)
|
||||
\\(\approx Log[ E_q[ f(Z) p(Z) / q(Z) ] ]\\)
|
||||
\\(= Log[E_p[f(Z)]]\\)
|
||||
|
||||
This integral is done in log-space with max-subtraction to better handle the
|
||||
often extreme values that `f(z) p(z) / q(z)` can take on.
|
||||
|
||||
In contrast to `expectation_importance_sampler`, this `Op` returns values in
|
||||
log-space.
|
||||
|
||||
|
||||
User supplies either a `Tensor` of samples `z`, or the number of samples to draw, `n`.
|
||||
|
||||
Args:
|
||||
log_f: Callable mapping samples from `sampling_dist_q` to `Tensors` with
|
||||
shape broadcastable to `q.batch_shape`.
|
||||
For example, `log_f` works "just like" `sampling_dist_q.log_prob`.
|
||||
log_p: Callable mapping samples from `sampling_dist_q` to `Tensors` with
|
||||
shape broadcastable to `q.batch_shape`.
|
||||
For example, `log_p` works "just like" `q.log_prob`.
|
||||
sampling_dist_q: The sampling distribution.
|
||||
`tfp.distributions.Distribution`.
|
||||
`float64` `dtype` recommended.
|
||||
`log_p` and `q` should be supported on the same set.
|
||||
z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
|
||||
n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
|
||||
seed: Python integer to seed the random number generator.
|
||||
name: A name to give this `Op`.
|
||||
|
||||
Returns:
|
||||
Logarithm of the importance sampling estimate. `Tensor` with `shape` equal
|
||||
to batch shape of `q`, and `dtype` = `q.dtype`.
|
||||
"""
|
||||
q = sampling_dist_q
|
||||
with ops.name_scope(name, values=[z, n]):
|
||||
z = _get_samples(q, z, n, seed)
|
||||
log_values = log_f(z) + log_p(z) - q.log_prob(z)
|
||||
return _logspace_mean(log_values)
|
||||
|
||||
|
||||
def _logspace_mean(log_values):
|
||||
"""Evaluate `Log[E[values]]` in a stable manner.
|
||||
|
||||
Args:
|
||||
log_values: `Tensor` holding `Log[values]`.
|
||||
|
||||
Returns:
|
||||
`Tensor` of same `dtype` as `log_values`, reduced across dim 0.
|
||||
`Log[Mean[values]]`.
|
||||
"""
|
||||
# center = Max[Log[values]], with stop-gradient
|
||||
# The center hopefully keeps the exponentiated term small. It is canceled
# from the final result, so putting a stop_gradient on it will not change the
# final result. We apply stop_gradient to eliminate unnecessary computation.
|
||||
center = array_ops.stop_gradient(_sample_max(log_values))
|
||||
|
||||
# centered_values = exp{Log[values] - center}, where center = Max[Log[values]]
|
||||
centered_values = math_ops.exp(log_values - center)
|
||||
|
||||
# log_mean_of_values = Log[ E[centered_values] ] + center
|
||||
# = Log[ E[exp{log_values - E[log_values]}] ] + center
|
||||
# = Log[E[values]] - E[log_values] + center
|
||||
# = Log[E[values]]
|
||||
log_mean_of_values = math_ops.log(_sample_mean(centered_values)) + center
|
||||
|
||||
return log_mean_of_values
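For intuition, the same max-subtraction (log-mean-exp) trick in plain NumPy; this sketch is independent of the TensorFlow implementation above and uses made-up values.

```python
import numpy as np

def logspace_mean(log_values):
  # Stable Log[Mean[values]]: subtract the max before exponentiating.
  center = np.max(log_values)
  return np.log(np.mean(np.exp(log_values - center))) + center

log_values = np.array([-1000.0, -1001.0, -1002.0])
print(logspace_mean(log_values))            # ~ -1000.69, finite
print(np.log(np.mean(np.exp(log_values))))  # -inf: exp underflows to 0
```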
|
||||
|
||||
|
||||
@deprecation.deprecated(
|
||||
'2018-10-01',
|
||||
'The tf.contrib.bayesflow library has moved to '
|
||||
'TensorFlow Probability (https://github.com/tensorflow/probability). '
|
||||
'Use `tfp.monte_carlo.expectation` instead.',
|
||||
warn_once=True)
|
||||
def expectation(f, samples, log_prob=None, use_reparametrization=True,
|
||||
axis=0, keep_dims=False, name=None):
|
||||
r"""Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\).
|
||||
|
||||
This function computes the Monte-Carlo approximation of an expectation, i.e.,
|
||||
|
||||
\\(E_p[f(X)] \approx m^{-1} sum_{j=1}^m f(x_j), x_j\ ~iid\ p(X)\\)
|
||||
|
||||
where:
|
||||
|
||||
- `x_j = samples[j, ...]`,
|
||||
- `log(p(samples)) = log_prob(samples)` and
|
||||
- `m = prod(shape(samples)[axis])`.
|
||||
|
||||
Tricks: Reparameterization and Score-Gradient
|
||||
|
||||
When p is "reparameterized", i.e., a diffeomorphic transformation of a
|
||||
parameterless distribution (e.g.,
|
||||
`Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and
|
||||
expectation, i.e.,
|
||||
grad[ Avg{ \\(s_i : i=1...n\\) } ] = Avg{ grad[\\(s_i\\)] : i=1...n }, where
\\(S_n = Avg{s_i}\\) and \\(s_i = f(x_i), x_i ~ p\\).
|
||||
|
||||
However, if p is not reparameterized, TensorFlow's gradient will be incorrect
|
||||
since the chain-rule stops at samples of non-reparameterized distributions.
|
||||
(The non-differentiated result, `approx_expectation`, is the same regardless
|
||||
of `use_reparametrization`.) In this circumstance using the Score-Gradient
|
||||
trick results in an unbiased gradient, i.e.,
|
||||
|
||||
```none
|
||||
grad[ E_p[f(X)] ]
|
||||
= grad[ int dx p(x) f(x) ]
|
||||
= int dx grad[ p(x) f(x) ]
|
||||
= int dx [ p'(x) f(x) + p(x) f'(x) ]
|
||||
= int dx p(x) [p'(x) / p(x) f(x) + f'(x) ]
|
||||
= int dx p(x) grad[ f(x) p(x) / stop_grad[p(x)] ]
|
||||
= E_p[ grad[ f(x) p(x) / stop_grad[p(x)] ] ]
|
||||
```
|
||||
|
||||
Unless p cannot be reparameterized, it is usually preferable to set
`use_reparametrization = True`.
|
||||
|
||||
Warning: users are responsible for verifying `p` is a "reparameterized"
|
||||
distribution.
|
||||
|
||||
Example Use:
|
||||
|
||||
```python
|
||||
import tensorflow_probability as tfp
|
||||
tfd = tfp.distributions
|
||||
|
||||
# Monte-Carlo approximation of a reparameterized distribution, e.g., Normal.
|
||||
|
||||
num_draws = int(1e5)
|
||||
p = tfd.Normal(loc=0., scale=1.)
|
||||
q = tfd.Normal(loc=1., scale=2.)
|
||||
exact_kl_normal_normal = tfd.kl_divergence(p, q)
|
||||
# ==> 0.44314718
|
||||
approx_kl_normal_normal = tfp.monte_carlo.expectation(
|
||||
f=lambda x: p.log_prob(x) - q.log_prob(x),
|
||||
samples=p.sample(num_draws, seed=42),
|
||||
log_prob=p.log_prob,
|
||||
use_reparametrization=(p.reparameterization_type
|
||||
== distribution.FULLY_REPARAMETERIZED))
|
||||
# ==> 0.44632751
|
||||
# Relative Error: <1%
|
||||
|
||||
# Monte-Carlo approximation of non-reparameterized distribution, e.g., Gamma.
|
||||
|
||||
num_draws = int(1e5)
|
||||
p = tfd.Gamma(concentration=1., rate=1.)
q = tfd.Gamma(concentration=2., rate=3.)
|
||||
exact_kl_gamma_gamma = tfd.kl_divergence(p, q)
|
||||
# ==> 0.37999129
|
||||
approx_kl_gamma_gamma = tfp.monte_carlo.expectation(
|
||||
f=lambda x: p.log_prob(x) - q.log_prob(x),
|
||||
samples=p.sample(num_draws, seed=42),
|
||||
log_prob=p.log_prob,
|
||||
use_reparametrization=(p.reparameterization_type
|
||||
== distribution.FULLY_REPARAMETERIZED))
|
||||
# ==> 0.37696719
|
||||
# Relative Error: <1%
|
||||
|
||||
# For comparing the gradients, see `monte_carlo_test.py`.
|
||||
```
|
||||
|
||||
Note: The above example is for illustration only. To compute approximate
|
||||
KL-divergence, the following is preferred:
|
||||
|
||||
```python
|
||||
approx_kl_p_q = tfp.vi.monte_carlo_csiszar_f_divergence(
|
||||
f=tfp.vi.kl_reverse,
|
||||
p_log_prob=q.log_prob,
|
||||
q=p,
|
||||
num_draws=num_draws)
|
||||
```
|
||||
|
||||
Args:
|
||||
f: Python callable which can return `f(samples)`.
|
||||
samples: `Tensor` of samples used to form the Monte-Carlo approximation of
|
||||
\\(E_p[f(X)]\\). A batch of samples should be indexed by `axis`
|
||||
dimensions.
|
||||
log_prob: Python callable which can return `log_prob(samples)`. Must
|
||||
correspond to the natural-logarithm of the pdf/pmf of each sample. Only
|
||||
required/used if `use_reparametrization=False`.
|
||||
Default value: `None`.
|
||||
use_reparametrization: Python `bool` indicating that the approximation
|
||||
should use the fact that the gradient of samples is unbiased. Whether
|
||||
`True` or `False`, this arg only affects the gradient of the resulting
|
||||
`approx_expectation`.
|
||||
Default value: `True`.
|
||||
axis: The dimensions to average. If `None`, averages all
|
||||
dimensions.
|
||||
Default value: `0` (the left-most dimension).
|
||||
keep_dims: If True, retains averaged dimensions using size `1`.
|
||||
Default value: `False`.
|
||||
name: A `name_scope` for operations created by this function.
|
||||
Default value: `None` (which implies "expectation").
|
||||
|
||||
Returns:
|
||||
approx_expectation: `Tensor` corresponding to the Monte-Carlo approximation
|
||||
of \\(E_p[f(X)]\\).
|
||||
|
||||
Raises:
|
||||
ValueError: if `f` is not a Python `callable`.
|
||||
ValueError: if `use_reparametrization=False` and `log_prob` is not a Python
|
||||
`callable`.
|
||||
"""
|
||||
|
||||
with ops.name_scope(name, 'expectation', [samples]):
|
||||
if not callable(f):
|
||||
raise ValueError('`f` must be a callable function.')
|
||||
if use_reparametrization:
|
||||
return math_ops.reduce_mean(f(samples), axis=axis, keepdims=keep_dims)
|
||||
else:
|
||||
if not callable(log_prob):
|
||||
raise ValueError('`log_prob` must be a callable function.')
|
||||
stop = array_ops.stop_gradient # For readability.
|
||||
x = stop(samples)
|
||||
logpx = log_prob(x)
|
||||
fx = f(x) # Call `f` once in case it has side-effects.
|
||||
# We now rewrite f(x) so that:
|
||||
# `grad[f(x)] := grad[f(x)] + f(x) * grad[logpx]`.
|
||||
# To achieve this, we use a trick that
|
||||
# `h(x) - stop(h(x)) == zeros_like(h(x))`
|
||||
# but its gradient is grad[h(x)].
|
||||
# Note that IEEE754 specifies that `x - x == 0.` and `x + 0. == x`, hence
|
||||
# this trick loses no precision. For more discussion regarding the
|
||||
# relevant portions of the IEEE754 standard, see the StackOverflow
|
||||
# question,
|
||||
# "Is there a floating point value of x, for which x-x == 0 is false?"
|
||||
# http://stackoverflow.com/q/2686644
|
||||
fx += stop(fx) * (logpx - stop(logpx)) # Add zeros_like(logpx).
|
||||
return math_ops.reduce_mean(fx, axis=axis, keepdims=keep_dims)
|
||||
|
||||
|
||||
def _sample_mean(values):
|
||||
"""Mean over sample indices. In this module this is always [0]."""
|
||||
return math_ops.reduce_mean(values, axis=[0])
|
||||
|
||||
|
||||
def _sample_max(values):
|
||||
"""Max over sample indices. In this module this is always [0]."""
|
||||
return math_ops.reduce_max(values, axis=[0])
|
||||
|
||||
|
||||
def _get_samples(dist, z, n, seed):
|
||||
"""Check args and return samples."""
|
||||
with ops.name_scope('get_samples', values=[z, n]):
|
||||
if (n is None) == (z is None):
|
||||
raise ValueError(
|
||||
'Must specify exactly one of arguments "n" and "z". Found: '
|
||||
'n = %s, z = %s' % (n, z))
|
||||
if n is not None:
|
||||
return dist.sample(n, seed=seed)
|
||||
else:
|
||||
return ops.convert_to_tensor(z, name='z')
|
@ -1,4 +0,0 @@
|
||||
# Benchmarking Scripts
|
||||
|
||||
This directory tree contains a set of scripts that are useful when benchmarking
|
||||
TensorFlow.
|
@ -1,10 +0,0 @@
|
||||
# Distributed TensorFlow on Google Compute Engine
|
||||
|
||||
The scripts in this directory automate the work of running distributed
TensorFlow on a cluster of GCE instances.
|
||||
|
||||
## Pre-work
|
||||
|
||||
Before getting started, while GPUs on GCE are in Alpha, you will need to get
|
||||
your project whitelisted in order to get access. These scripts will not work
|
||||
until then.
|
@ -1,212 +0,0 @@
|
||||
# Cloud Bigtable client for TensorFlow
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
"tf_copts",
|
||||
"tf_custom_op_library",
|
||||
"tf_gen_op_libs",
|
||||
"tf_gen_op_wrapper_py",
|
||||
"tf_kernel_library",
|
||||
"tf_py_test",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = ["//tensorflow:internal"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
tf_custom_op_py_library(
|
||||
name = "bigtable",
|
||||
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
|
||||
dso = [
|
||||
":python/ops/_bigtable.so",
|
||||
],
|
||||
kernels = [
|
||||
":bigtable_kernels",
|
||||
":bigtable_ops_op_lib",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":bigtable_ops",
|
||||
"//tensorflow/contrib/data/python/ops:interleave_ops",
|
||||
"//tensorflow/contrib/util:util_py",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data",
|
||||
],
|
||||
)
|
||||
|
||||
KERNEL_FILES = [
|
||||
"kernels/bigtable_kernels.cc",
|
||||
"kernels/bigtable_lookup_dataset_op.cc",
|
||||
"kernels/bigtable_prefix_key_dataset_op.cc",
|
||||
"kernels/bigtable_range_key_dataset_op.cc",
|
||||
"kernels/bigtable_sample_keys_dataset_op.cc",
|
||||
"kernels/bigtable_sample_key_pairs_dataset_op.cc",
|
||||
"kernels/bigtable_scan_dataset_op.cc",
|
||||
]
|
||||
|
||||
tf_custom_op_library(
|
||||
name = "python/ops/_bigtable.so",
|
||||
srcs = KERNEL_FILES + [
|
||||
"ops/bigtable_ops.cc",
|
||||
],
|
||||
deps = [
|
||||
":bigtable_lib_cc",
|
||||
":bigtable_range_helpers",
|
||||
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "bigtable_ops",
|
||||
deps = [":bigtable_ops_op_lib"],
|
||||
)
|
||||
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = [
|
||||
"bigtable_ops",
|
||||
"bigtable_test_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "bigtable_kernels",
|
||||
srcs = KERNEL_FILES,
|
||||
deps = [
|
||||
":bigtable_lib_cc",
|
||||
":bigtable_range_helpers",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//third_party/eigen3",
|
||||
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
|
||||
],
|
||||
)
|
||||
|
||||
# A library for use in the bigtable kernels.
|
||||
cc_library(
|
||||
name = "bigtable_lib_cc",
|
||||
srcs = ["kernels/bigtable_lib.cc"],
|
||||
hdrs = ["kernels/bigtable_lib.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//third_party/eigen3",
|
||||
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "bigtable_range_helpers",
|
||||
srcs = ["kernels/bigtable_range_helpers.cc"],
|
||||
hdrs = ["kernels/bigtable_range_helpers.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "bigtable_test_client",
|
||||
srcs = ["kernels/test_kernels/bigtable_test_client.cc"],
|
||||
hdrs = ["kernels/test_kernels/bigtable_test_client.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"@com_github_googleapis_googleapis//:bigtable_protos",
|
||||
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
|
||||
"@com_googlesource_code_re2//:re2",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "bigtable_test_client_test",
|
||||
srcs = ["kernels/test_kernels/bigtable_test_client_test.cc"],
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
":bigtable_test_client",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "bigtable_range_helpers_test",
|
||||
size = "small",
|
||||
srcs = ["kernels/bigtable_range_helpers_test.cc"],
|
||||
deps = [
|
||||
":bigtable_range_helpers",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "bigtable_test_ops",
|
||||
deps = [":bigtable_test_ops_op_lib"],
|
||||
)
|
||||
|
||||
tf_custom_op_library(
|
||||
name = "python/kernel_tests/_bigtable_test.so",
|
||||
srcs = [
|
||||
"kernels/test_kernels/bigtable_test_client_op.cc",
|
||||
"ops/bigtable_test_ops.cc",
|
||||
],
|
||||
deps = [
|
||||
":bigtable_lib_cc",
|
||||
":bigtable_test_client",
|
||||
"@com_googlesource_code_re2//:re2",
|
||||
],
|
||||
)
|
||||
|
||||
# Don't use tf_kernel_library because it prevents access to strings/stringprintf.h
|
||||
cc_library(
|
||||
name = "bigtable_test_kernels",
|
||||
srcs = [
|
||||
"kernels/test_kernels/bigtable_test_client_op.cc",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
":bigtable_lib_cc",
|
||||
":bigtable_test_client",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//third_party/eigen3",
|
||||
"@com_googlesource_code_re2//:re2",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_custom_op_py_library(
|
||||
name = "bigtable_test_py",
|
||||
dso = [
|
||||
":python/kernel_tests/_bigtable_test.so",
|
||||
],
|
||||
kernels = [
|
||||
":bigtable_test_kernels",
|
||||
":bigtable_test_ops_op_lib",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":bigtable_test_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "bigtable_ops_test",
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/bigtable_ops_test.py"],
|
||||
additional_deps = [
|
||||
":bigtable",
|
||||
":bigtable_test_py",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/contrib/util:util_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
tags = ["manual"],
|
||||
)
|
@ -1,346 +0,0 @@
|
||||
# Google Cloud Bigtable
|
||||
|
||||
[Cloud Bigtable](https://cloud.google.com/bigtable/) is a high
|
||||
performance storage system that can store and serve training data. This contrib
|
||||
package contains an experimental integration with TensorFlow.
|
||||
|
||||
> **Status: Highly experimental.** The current implementation is very much in
|
||||
> flux. Please use at your own risk! :-)
|
||||
|
||||
The TensorFlow integration with Cloud Bigtable is optimized for common
|
||||
TensorFlow usage and workloads. It currently focuses on reading from Cloud
|
||||
Bigtable at high speed, in particular to feed modern accelerators. For
|
||||
general-purpose Cloud Bigtable
|
||||
APIs, see the [official Cloud Bigtable client library documentation][clientdoc].
|
||||
|
||||
[clientdoc]: https://cloud.google.com/bigtable/docs/reference/libraries
|
||||
|
||||
## Sample Use
|
||||
|
||||
There are three main reading styles supported by the `BigtableTable` class:
|
||||
|
||||
1. **Reading keys**: Read only the row keys in a table. Keys are returned in
|
||||
sorted order from the table. Most key reading operations retrieve all keys
|
||||
   in a contiguous range; however, the `sample_keys` operation skips keys and
|
||||
operates on the whole table (and not a contiguous subset).
|
||||
2. **Retrieving a row's values**: Given a row key, look up the data associated
|
||||
with a defined set of columns. This operation takes advantage of Cloud
|
||||
Bigtable's low-latency and excellent support for random access.
|
||||
3. **Scanning ranges**: Given a contiguous range of rows retrieve both the row
|
||||
key and the data associated with a fixed set of columns. This operation
|
||||
takes advantage of Cloud Bigtable's high throughput scans, and is the most
|
||||
efficient way to read data.
|
||||
|
||||
When using the Cloud Bigtable API, the workflow is:
|
||||
|
||||
1. Create a `BigtableClient` object.
|
||||
2. Use the `BigtableClient` to create `BigtableTable` objects corresponding to
|
||||
each table in the Cloud Bigtable instance you would like to access.
|
||||
3. Call methods on the `BigtableTable` object to create `tf.data.Dataset`s to
|
||||
retrieve data.
|
||||
|
||||
The following is an example for how to read all row keys with the prefix
|
||||
`train-`.
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
|
||||
GCP_PROJECT_ID = '<FILL_ME_IN>'
|
||||
BIGTABLE_INSTANCE_ID = '<FILL_ME_IN>'
|
||||
BIGTABLE_TABLE_NAME = '<FILL_ME_IN>'
|
||||
PREFIX = 'train-'
|
||||
|
||||
def main():
|
||||
tf.enable_eager_execution()
|
||||
|
||||
client = tf.contrib.cloud.BigtableClient(GCP_PROJECT_ID, BIGTABLE_INSTANCE_ID)
|
||||
table = client.table(BIGTABLE_TABLE_NAME)
|
||||
dataset = table.keys_by_prefix_dataset(PREFIX)
|
||||
|
||||
print('Retrieving rows:')
|
||||
row_index = 0
|
||||
for row_key in dataset:
|
||||
print('Row key %d: %s' % (row_index, row_key))
|
||||
row_index += 1
|
||||
print('Finished reading data!')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
```
|
||||
|
||||
### Reading row keys
|
||||
|
||||
Read only the row keys in a table. Keys are returned in sorted order from the
|
||||
table. Most key reading operations retrieve all keys in a contiguous range;
|
||||
however, the `sample_keys` operation skips keys and operates on the whole table
|
||||
(and not a contiguous subset).
|
||||
|
||||
There are three methods to retrieve row keys (a short sketch follows this list):
|
||||
|
||||
- `table.keys_by_range_dataset(start, end)`: Retrieves row keys starting with
|
||||
`start`, and ending with `end`. The range is "half-open", and thus it
|
||||
includes `start` if `start` is present in the table. It does not include
|
||||
`end`.
|
||||
- `table.keys_by_prefix_dataset(prefix)`: Retrieves all row keys that start
|
||||
with `prefix`. It includes the row key `prefix` if present in the table.
|
||||
- `table.sample_keys()`: Retrieves a sampling of keys from the underlying
|
||||
table. This is often useful in conjunction with parallel scans.
|
||||
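As a quick sketch, assuming eager execution and the placeholder `table` object
from the sample above (the key boundaries and prefix below are illustrative
only, not values from a real table), the three methods look like this:

```python
# Illustrative sketch: the key boundaries and prefix are placeholders.
range_keys = table.keys_by_range_dataset('train-00000', 'train-10000')
prefix_keys = table.keys_by_prefix_dataset('train-')
sampled_keys = table.sample_keys()

for row_key in prefix_keys.take(5):
  print(row_key)  # Each element is a scalar `tf.string` row key.
```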
|
||||
### Reading cell values given a row key
|
||||
|
||||
Given a dataset producing row keys, you can use the `table.lookup_columns`
|
||||
transformation to retrieve values. Example:
|
||||
|
||||
```python
|
||||
key_dataset = tf.data.Dataset.from_tensor_slices([
|
||||
'row_key_1',
|
||||
'other_row_key',
|
||||
'final_row_key',
|
||||
])
|
||||
values_dataset = key_dataset.apply(
|
||||
table.lookup_columns(('my_column_family', 'column_name'),
|
||||
('other_cf', 'col')))
|
||||
training_data = values_dataset.map(my_parsing_function) # ...
|
||||
```
|
||||
|
||||
### Scanning ranges
|
||||
Given a contiguous range of rows retrieve both the row key and the data
|
||||
associated with a fixed set of columns. Scanning is the most efficient way to
|
||||
retrieve data from Cloud Bigtable and is thus a very common API for high
|
||||
performance data pipelines. To construct a scanning `tf.data.Dataset` from a
|
||||
`BigtableTable` object, call one of the following methods:
|
||||
|
||||
- `table.scan_prefix(prefix, ...)`
|
||||
- `table.scan_range(start, end, ...)`
|
||||
- `table.parallel_scan_prefix(prefix, ...)`
|
||||
- `table.parallel_scan_range(start, end, ...)`
|
||||
|
||||
Aside from the specification of the contiguous range of rows, they all take the
|
||||
following arguments:
|
||||
|
||||
- `probability`: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
|
||||
  A value less than 1 causes rows to be sampled probabilistically with the
|
||||
provided probability.
|
||||
- `columns`: The columns to read. (See below.)
|
||||
- `**kwargs`: The columns to read, passed as keyword arguments. (See below.)
|
||||
|
||||
In addition, the two parallel operations accept the following optional argument:
|
||||
`num_parallel_scans`, which configures the number of parallel Cloud Bigtable scan
|
||||
operations to run. A reasonable default is automatically chosen for small
|
||||
Cloud Bigtable clusters. If you have a large cluster, or an extremely demanding
|
||||
workload, you can tune this value to optimize performance.
|
||||
|
||||
#### Specifying columns to read when scanning
|
||||
|
||||
All of the scan operations allow you to specify the column family and columns
|
||||
in the same ways.
|
||||
|
||||
##### Using `columns`
|
||||
|
||||
The first way to specify the data to read is via the `columns` parameter. The
|
||||
value should be a tuple (or list of tuples) of strings. The first string in the
|
||||
tuple is the column family, and the second string in the tuple is the column
|
||||
qualifier.
|
||||
|
||||
##### Using `**kwargs`
|
||||
|
||||
The second way to specify the data to read is via the `**kwargs` parameter,
|
||||
which you can use to specify keyword arguments corresponding to the columns that
|
||||
you want to read. The keyword to use is the column family name, and the argument
|
||||
value should be either a string, or a tuple of strings, specifying the column
|
||||
qualifiers (column names).
|
||||
|
||||
Although using `**kwargs` has the advantage of requiring less typing, it is not
|
||||
future-proof in all cases. (If we add a new parameter to the scan functions that
|
||||
has the same name as your column family, your code will break.)
|
||||
|
||||
##### Examples
|
||||
|
||||
Below are two equivalent snippets for how to specify which columns to read:
|
||||
|
||||
```python
|
||||
ds1 = table.scan_range("row_start", "row_end", columns=[("cfa", "c1"),
|
||||
("cfa", "c2"),
|
||||
("cfb", "c3")])
|
||||
ds2 = table.scan_range("row_start", "row_end", cfa=["c1", "c2"], cfb="c3")
|
||||
```
|
||||
|
||||
In this example, we are reading 3 columns from a total of 2 column families.
|
||||
From the `cfa` column family, we are reading columns `c1`, and `c2`. From the
|
||||
second column family (`cfb`), we are reading `c3`. Both `ds1` and `ds2` will
|
||||
output elements of the following types (`tf.string`, `tf.string`, `tf.string`,
|
||||
`tf.string`). The first `tf.string` is the row key, the second `tf.string` is
|
||||
the latest data in cell `cfa:c1`, the third corresponds to `cfa:c2`, and the
|
||||
final one is `cfb:c3`.
|
||||
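For larger clusters or more demanding workloads, the same columns can be read
with the parallel variants. The sketch below is illustrative only: the key
boundaries, prefix, and `num_parallel_scans` value are placeholders, not
recommendations.

```python
# Illustrative sketch: key boundaries, prefix, and num_parallel_scans are
# placeholders; tune num_parallel_scans to your cluster and workload.
ds3 = table.parallel_scan_range("row_start", "row_end",
                                num_parallel_scans=8,
                                columns=[("cfa", "c1"),
                                         ("cfa", "c2"),
                                         ("cfb", "c3")])
ds4 = table.parallel_scan_prefix("train-", num_parallel_scans=8,
                                 cfa=["c1", "c2"], cfb="c3")
```

Both produce the same `(row key, cfa:c1, cfa:c2, cfb:c3)` element structure as
`ds1` and `ds2`, but, as described in the next section, the element order is
not deterministic.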
|
||||
#### Determinism when scanning
|
||||
|
||||
While the non-parallel scan operations are fully deterministic, the parallel
|
||||
scan operations are not. If you would like to scan in parallel without losing
|
||||
determinism, you can build up the `parallel_interleave` yourself. As an example,
|
||||
say we wanted to scan all rows between `training_data_00000`, and
|
||||
`training_data_90000`, we can use the following code snippet:
|
||||
|
||||
```python
|
||||
table = # ...
|
||||
columns = [('cf1', 'col1'), ('cf1', 'col2')]
|
||||
NUM_PARALLEL_READS = # ...
|
||||
ds = tf.data.Dataset.range(9).shuffle(10)
|
||||
def interleave_fn(index):
|
||||
# Given a starting index, create 2 strings to be the start and end
|
||||
start_idx = index
|
||||
end_idx = index + 1
|
||||
start_idx_str = tf.as_string(start_idx * 10000, width=5, fill='0')
|
||||
end_idx_str = tf.as_string(end_idx * 10000, width=5, fill='0')
|
||||
start = tf.string_join(['training_data_', start_idx_str])
|
||||
end = tf.string_join(['training_data_', end_idx_str])
|
||||
  return table.scan_range(start, end, columns=columns)
|
||||
ds = ds.apply(tf.data.experimental.parallel_interleave(
|
||||
interleave_fn, cycle_length=NUM_PARALLEL_READS, prefetch_input_elements=1))
|
||||
```
|
||||
|
||||
> Note: you should divide up the key range into more sub-ranges for increased
|
||||
> parallelism.
|
||||
|
||||
## Writing to Cloud Bigtable
|
||||
|
||||
In order to simplify getting started, this package provides basic support for
|
||||
writing data into Cloud Bigtable.
|
||||
|
||||
> Note: The implementation is not optimized for performance! Please consider
|
||||
> using alternative frameworks such as Apache Beam / Cloud Dataflow for
|
||||
> production workloads.
|
||||
|
||||
Below is an example for how to write a trivial dataset into Cloud Bigtable.
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
|
||||
GCP_PROJECT_ID = '<FILL_ME_IN>'
|
||||
BIGTABLE_INSTANCE_ID = '<FILL_ME_IN>'
|
||||
BIGTABLE_TABLE_NAME = '<FILL_ME_IN>'
|
||||
COLUMN_FAMILY = '<FILL_ME_IN>'
|
||||
COLUMN_QUALIFIER = '<FILL_ME_IN>'
|
||||
|
||||
def make_dataset():
|
||||
"""Makes a dataset to write to Cloud Bigtable."""
|
||||
return tf.data.Dataset.from_tensor_slices([
|
||||
'training_data_1',
|
||||
'training_data_2',
|
||||
'training_data_3',
|
||||
])
|
||||
|
||||
def make_row_key_dataset():
|
||||
"""Makes a dataset of strings used for row keys.
|
||||
|
||||
The strings are of the form: `fake-data-` followed by a sequential counter.
|
||||
For example, this dataset would contain the following elements:
|
||||
|
||||
- fake-data-00000001
|
||||
- fake-data-00000002
|
||||
- ...
|
||||
- fake-data-23498103
|
||||
"""
|
||||
counter_dataset = tf.data.experimental.Counter()
|
||||
width = 8
|
||||
row_key_prefix = 'fake-data-'
|
||||
ds = counter_dataset.map(lambda index: tf.as_string(index,
|
||||
width=width,
|
||||
fill='0'))
|
||||
ds = ds.map(lambda idx_str: tf.string_join([row_key_prefix, idx_str]))
|
||||
return ds
|
||||
|
||||
|
||||
def main():
|
||||
client = tf.contrib.cloud.BigtableClient(GCP_PROJECT_ID, BIGTABLE_INSTANCE_ID)
|
||||
table = client.table(BIGTABLE_TABLE_NAME)
|
||||
dataset = make_dataset()
|
||||
index_dataset = make_row_key_dataset()
|
||||
aggregate_dataset = tf.data.Dataset.zip((index_dataset, dataset))
|
||||
write_op = table.write(aggregate_dataset, column_families=[COLUMN_FAMILY],
|
||||
columns=[COLUMN_QUALIFIER])
|
||||
|
||||
with tf.Session() as sess:
|
||||
print('Starting transfer.')
|
||||
sess.run(write_op)
|
||||
print('Transfer complete.')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
```
|
||||
|
||||
## Sample applications and architectures
|
||||
|
||||
While most machine learning applications are well served by a high-performance
|
||||
distributed file system, there are certain applications where using Cloud
|
||||
Bigtable works extremely well.
|
||||
|
||||
### Perfect Shuffling
|
||||
|
||||
Normally, training data is stored in flat files, and a combination of
|
||||
(1) `tf.data.Dataset.interleave` (or `parallel_interleave`), (2)
|
||||
`tf.data.Dataset.shuffle`, and (3) writing the data in an unsorted order in the
|
||||
data files in the first place, provides enough randomization to ensure models
|
||||
train efficiently. However, if you would like perfect shuffling, you can use
|
||||
Cloud Bigtable's low-latency random access capabilities. Create a
|
||||
`tf.data.Dataset` that generates the keys in a perfectly random order (or read
|
||||
all the keys into memory and use a shuffle buffer sized to fit all of them for a
|
||||
perfect random shuffle using `tf.data.Dataset.shuffle`), and then use
|
||||
`lookup_columns` to retrieve the training data.
|
||||
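A minimal sketch of this pattern, assuming eager execution, the placeholder
`table` from the earlier samples, and that `all_row_keys` is a Python list
holding every row key (small enough to fit in memory):

```python
# Sketch only: `all_row_keys` is assumed to be a Python list of every row key,
# e.g. collected once by iterating table.keys_by_prefix_dataset('train-').
key_ds = tf.data.Dataset.from_tensor_slices(all_row_keys)
# A shuffle buffer that holds the full key list yields a perfect shuffle.
key_ds = key_ds.shuffle(buffer_size=len(all_row_keys),
                        reshuffle_each_iteration=True)
# Low-latency random access fetches the training data for each shuffled key.
values_ds = key_ds.apply(
    table.lookup_columns(('my_column_family', 'column_name')))
training_data = values_ds.map(my_parsing_function)  # As in the earlier example.
```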
|
||||
### Distributed Reinforcement Learning
|
||||
|
||||
Sophisticated reinforcement learning algorithms are commonly trained across a
|
||||
distributed cluster. (See [IMPALA by DeepMind][impala].) One part of the cluster
|
||||
runs self-play, while the other part of the cluster learns a new version of the
|
||||
model based on the training data generated by self-play. The new model version
|
||||
is then distributed to the self-play half of the cluster, and new training data
|
||||
is generated to continue the cycle.
|
||||
|
||||
In such a configuration, because there is value in training on the freshest
|
||||
examples, a storage service like Cloud Bigtable can be used to store and
|
||||
serve the generated training data. When using Cloud Bigtable, there is no need
|
||||
to aggregate the examples into large batch files, but the examples can instead
|
||||
be written as soon as they are generated, and then retrieved at high speed.
|
||||
|
||||
[impala]: https://arxiv.org/abs/1802.01561
|
||||
|
||||
## Common Gotchas!
|
||||
|
||||
### gRPC Certificates
|
||||
|
||||
If you encounter a log line that includes the following:
|
||||
|
||||
```
|
||||
"description":"Failed to load file", [...],
|
||||
"filename":"/usr/share/grpc/roots.pem"
|
||||
```
|
||||
|
||||
you can solve it via either of the following approaches:
|
||||
|
||||
* copy the [gRPC `roots.pem` file][grpcPem] to
|
||||
`/usr/share/grpc/roots.pem` on your local machine, which is the default
|
||||
location where gRPC will look for this file
|
||||
* export the environment variable `GRPC_DEFAULT_SSL_ROOTS_FILE_PATH` to point to
|
||||
the full path of the gRPC `roots.pem` file on your file system if it's in a
|
||||
  different location; a minimal sketch of this approach follows this list
|
||||
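For the second approach, a minimal sketch in Python (the path below is a
placeholder, and the variable has to be set before the Bigtable client is
created so that gRPC picks it up):

```python
import os

# Placeholder path: point this at the location of your gRPC roots.pem copy.
os.environ['GRPC_DEFAULT_SSL_ROOTS_FILE_PATH'] = '/path/to/grpc/roots.pem'
```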
|
||||
[grpcPem]: https://github.com/grpc/grpc/blob/master/etc/roots.pem
|
||||
|
||||
### Permission denied errors
|
||||
|
||||
The TensorFlow Cloud Bigtable client will search for credentials to use in the
|
||||
process's environment. It will use the first credentials it finds if multiple
|
||||
are available.
|
||||
|
||||
- **Compute Engine**: When running on Compute Engine, the client will often use
|
||||
the service account from the virtual machine's metadata service. Be sure to
|
||||
authorize your Compute Engine VM to have access to the Cloud Bigtable service
|
||||
when creating your VM, or [update the VM's scopes][update-vm-scopes] on a
|
||||
running VM if you run into this issue.
|
||||
- **Cloud TPU**: Your Cloud TPUs run with the designated Cloud TPU service
|
||||
account dedicated to your GCP project. Ensure the service account has been
|
||||
authorized via the Cloud Console to access your Cloud Bigtable instances.
|
||||
|
||||
[update-vm-scopes]: https://cloud.google.com/compute/docs/access/create-enable-service-accounts-for-instances#changeserviceaccountandscopes
|
@ -1,39 +0,0 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Cloud Bigtable Client for TensorFlow.
|
||||
|
||||
This contrib package allows TensorFlow to interface directly with Cloud Bigtable
|
||||
for high-speed data loading.
|
||||
|
||||
@@BigtableClient
|
||||
@@BigtableTable
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableClient
|
||||
from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableTable
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
_allowed_symbols = [
|
||||
'BigtableClient',
|
||||
'BigtableTable',
|
||||
]
|
||||
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
@ -1,360 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class BigtableClientOp : public OpKernel {
|
||||
public:
|
||||
explicit BigtableClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("project_id", &project_id_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("instance_id", &instance_id_));
|
||||
OP_REQUIRES(ctx, !project_id_.empty(),
|
||||
errors::InvalidArgument("project_id must be non-empty"));
|
||||
OP_REQUIRES(ctx, !instance_id_.empty(),
|
||||
errors::InvalidArgument("instance_id must be non-empty"));
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->GetAttr("connection_pool_size", &connection_pool_size_));
|
||||
// If left unset by the client code, set it to a default of 100. Note: the
|
||||
// cloud-cpp default of 4 concurrent connections is far too low for high
|
||||
// performance streaming.
|
||||
if (connection_pool_size_ == -1) {
|
||||
connection_pool_size_ = 100;
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("max_receive_message_size",
|
||||
&max_receive_message_size_));
|
||||
    // If left unset by the client code, set it to a default of 16 MB.
|
||||
if (max_receive_message_size_ == -1) {
|
||||
max_receive_message_size_ = 1 << 24; // 16 MBytes
|
||||
}
|
||||
OP_REQUIRES(ctx, max_receive_message_size_ > 0,
|
||||
errors::InvalidArgument("connection_pool_size must be > 0"));
|
||||
}
|
||||
|
||||
~BigtableClientOp() override {
|
||||
if (cinfo_.resource_is_private_to_kernel()) {
|
||||
if (!cinfo_.resource_manager()
|
||||
->Delete<BigtableClientResource>(cinfo_.container(),
|
||||
cinfo_.name())
|
||||
.ok()) {
|
||||
      // Do nothing; the resource may already have been deleted by session resets.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
|
||||
mutex_lock l(mu_);
|
||||
if (!initialized_) {
|
||||
ResourceMgr* mgr = ctx->resource_manager();
|
||||
OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
|
||||
BigtableClientResource* resource;
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
mgr->LookupOrCreate<BigtableClientResource>(
|
||||
cinfo_.container(), cinfo_.name(), &resource,
|
||||
[this, ctx](
|
||||
BigtableClientResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
auto client_options =
|
||||
google::cloud::bigtable::ClientOptions()
|
||||
.set_connection_pool_size(connection_pool_size_)
|
||||
.set_data_endpoint("batch-bigtable.googleapis.com");
|
||||
auto channel_args = client_options.channel_arguments();
|
||||
channel_args.SetMaxReceiveMessageSize(
|
||||
max_receive_message_size_);
|
||||
channel_args.SetUserAgentPrefix("tensorflow");
|
||||
channel_args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 0);
|
||||
channel_args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 60 * 1000);
|
||||
client_options.set_channel_arguments(channel_args);
|
||||
std::shared_ptr<google::cloud::bigtable::DataClient> client =
|
||||
google::cloud::bigtable::CreateDefaultDataClient(
|
||||
project_id_, instance_id_, std::move(client_options));
|
||||
*ret = new BigtableClientResource(project_id_, instance_id_,
|
||||
std::move(client));
|
||||
return Status::OK();
|
||||
}));
|
||||
core::ScopedUnref resource_cleanup(resource);
|
||||
initialized_ = true;
|
||||
}
|
||||
OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
|
||||
ctx, 0, cinfo_.container(), cinfo_.name(),
|
||||
MakeTypeIndex<BigtableClientResource>()));
|
||||
}
|
||||
|
||||
private:
|
||||
string project_id_;
|
||||
string instance_id_;
|
||||
int64 connection_pool_size_;
|
||||
int32 max_receive_message_size_;
|
||||
|
||||
mutex mu_;
|
||||
ContainerInfo cinfo_ GUARDED_BY(mu_);
|
||||
bool initialized_ GUARDED_BY(mu_) = false;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("BigtableClient").Device(DEVICE_CPU),
|
||||
BigtableClientOp);
|
||||
|
||||
class BigtableTableOp : public OpKernel {
|
||||
public:
|
||||
explicit BigtableTableOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_));
|
||||
OP_REQUIRES(ctx, !table_.empty(),
|
||||
errors::InvalidArgument("table_name must be non-empty"));
|
||||
}
|
||||
|
||||
~BigtableTableOp() override {
|
||||
if (cinfo_.resource_is_private_to_kernel()) {
|
||||
if (!cinfo_.resource_manager()
|
||||
->Delete<BigtableTableResource>(cinfo_.container(),
|
||||
cinfo_.name())
|
||||
.ok()) {
|
||||
      // Do nothing; the resource may already have been deleted by session resets.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
|
||||
mutex_lock l(mu_);
|
||||
if (!initialized_) {
|
||||
ResourceMgr* mgr = ctx->resource_manager();
|
||||
OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
|
||||
|
||||
core::RefCountPtr<BigtableClientResource> client_resource;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource));
|
||||
|
||||
BigtableTableResource* resource;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
mgr->LookupOrCreate<BigtableTableResource>(
|
||||
cinfo_.container(), cinfo_.name(), &resource,
|
||||
[this, &client_resource](BigtableTableResource** ret) {
|
||||
*ret = new BigtableTableResource(
|
||||
client_resource.get(), table_);
|
||||
return Status::OK();
|
||||
}));
|
||||
initialized_ = true;
|
||||
}
|
||||
OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
|
||||
ctx, 0, cinfo_.container(), cinfo_.name(),
|
||||
MakeTypeIndex<BigtableTableResource>()));
|
||||
}
|
||||
|
||||
private:
|
||||
string table_; // Note: this is const after construction.
|
||||
|
||||
mutex mu_;
|
||||
ContainerInfo cinfo_ GUARDED_BY(mu_);
|
||||
bool initialized_ GUARDED_BY(mu_) = false;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("BigtableTable").Device(DEVICE_CPU),
|
||||
BigtableTableOp);
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
class ToBigtableOp : public AsyncOpKernel {
|
||||
public:
|
||||
explicit ToBigtableOp(OpKernelConstruction* ctx)
|
||||
: AsyncOpKernel(ctx),
|
||||
thread_pool_(new thread::ThreadPool(
|
||||
ctx->env(), ThreadOptions(),
|
||||
strings::StrCat("to_bigtable_op_", SanitizeThreadSuffix(name())),
|
||||
/* num_threads = */ 1, /* low_latency_hint = */ false)) {}
|
||||
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
|
||||
// The call to `iterator->GetNext()` may block and depend on an
|
||||
// inter-op thread pool thread, so we issue the call from the
|
||||
// owned thread pool.
|
||||
thread_pool_->Schedule([this, ctx, done]() {
|
||||
const Tensor* column_families_tensor;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, ctx->input("column_families", &column_families_tensor), done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, column_families_tensor->dims() == 1,
|
||||
errors::InvalidArgument("`column_families` must be a vector."), done);
|
||||
|
||||
const Tensor* columns_tensor;
|
||||
OP_REQUIRES_OK_ASYNC(ctx, ctx->input("columns", &columns_tensor), done);
|
||||
OP_REQUIRES_ASYNC(ctx, columns_tensor->dims() == 1,
|
||||
errors::InvalidArgument("`columns` must be a vector."),
|
||||
done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx,
|
||||
columns_tensor->NumElements() ==
|
||||
column_families_tensor->NumElements(),
|
||||
errors::InvalidArgument("len(column_families) != len(columns)"),
|
||||
done);
|
||||
|
||||
std::vector<string> column_families;
|
||||
column_families.reserve(column_families_tensor->NumElements());
|
||||
std::vector<string> columns;
|
||||
columns.reserve(column_families_tensor->NumElements());
|
||||
for (uint64 i = 0; i < column_families_tensor->NumElements(); ++i) {
|
||||
column_families.push_back(column_families_tensor->flat<tstring>()(i));
|
||||
columns.push_back(columns_tensor->flat<tstring>()(i));
|
||||
}
|
||||
|
||||
DatasetBase* dataset;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done);
|
||||
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
dataset->MakeIterator(IteratorContext(ctx), "ToBigtableOpIterator",
|
||||
&iterator),
|
||||
done);
|
||||
|
||||
int64 timestamp_int;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, ParseScalarArgument<int64>(ctx, "timestamp", ×tamp_int),
|
||||
done);
|
||||
OP_REQUIRES_ASYNC(ctx, timestamp_int >= -1,
|
||||
errors::InvalidArgument("timestamp must be >= -1"),
|
||||
done);
|
||||
|
||||
core::RefCountPtr<BigtableTableResource> resource;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource), done);
|
||||
|
||||
std::vector<Tensor> components;
|
||||
components.reserve(dataset->output_dtypes().size());
|
||||
bool end_of_sequence = false;
|
||||
do {
|
||||
::google::cloud::bigtable::BulkMutation mutation;
|
||||
// TODO(saeta): Make # of mutations configurable.
|
||||
for (uint64 i = 0; i < 100 && !end_of_sequence; ++i) {
|
||||
OP_REQUIRES_OK_ASYNC(ctx,
|
||||
iterator->GetNext(IteratorContext(ctx),
|
||||
&components, &end_of_sequence),
|
||||
done);
|
||||
if (!end_of_sequence) {
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
CreateMutation(std::move(components), column_families, columns,
|
||||
timestamp_int, &mutation),
|
||||
done);
|
||||
}
|
||||
components.clear();
|
||||
}
|
||||
::google::cloud::Status mutation_status;
|
||||
std::vector<::google::cloud::bigtable::FailedMutation> failures =
|
||||
resource->table().BulkApply(mutation);
|
||||
if (!failures.empty()) {
|
||||
mutation_status = failures.front().status();
|
||||
if (!mutation_status.ok()) {
|
||||
LOG(ERROR) << "Failure applying mutation: "
|
||||
<< mutation_status.code() << " - "
|
||||
<< mutation_status.message() << ".";
|
||||
}
|
||||
::google::bigtable::v2::MutateRowsRequest request;
|
||||
mutation.MoveTo(&request);
|
||||
for (const auto& failure : failures) {
|
||||
LOG(ERROR) << "Failure applying mutation on row ("
|
||||
<< failure.original_index() << "): "
|
||||
<< request.entries(failure.original_index()).row_key()
|
||||
<< " - error: " << failure.status().message() << ".";
|
||||
}
|
||||
}
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, failures.empty(),
|
||||
errors::Unknown("Failure while writing to Cloud Bigtable: ",
|
||||
mutation_status.code(), " - ",
|
||||
mutation_status.message(),
|
||||
"; # of mutation failures: ", failures.size(),
|
||||
". See the log for the specific error details."),
|
||||
done);
|
||||
} while (!end_of_sequence);
|
||||
done();
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
static string SanitizeThreadSuffix(string suffix) {
|
||||
string clean;
|
||||
for (int i = 0; i < suffix.size(); ++i) {
|
||||
const char ch = suffix[i];
|
||||
if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') ||
|
||||
(ch >= '0' && ch <= '9') || ch == '_' || ch == '-') {
|
||||
clean += ch;
|
||||
} else {
|
||||
clean += '_';
|
||||
}
|
||||
}
|
||||
return clean;
|
||||
}
|
||||
|
||||
Status CreateMutation(
|
||||
std::vector<Tensor> tensors, const std::vector<string>& column_families,
|
||||
const std::vector<string>& columns, int64 timestamp_int,
|
||||
::google::cloud::bigtable::BulkMutation* bulk_mutation) {
|
||||
if (tensors.size() != column_families.size() + 1) {
|
||||
return errors::InvalidArgument(
|
||||
"Iterator produced a set of Tensors shorter than expected");
|
||||
}
|
||||
::google::cloud::bigtable::SingleRowMutation mutation(
|
||||
std::move(tensors[0].scalar<tstring>()()));
|
||||
std::chrono::milliseconds timestamp(timestamp_int);
|
||||
for (size_t i = 1; i < tensors.size(); ++i) {
|
||||
if (!TensorShapeUtils::IsScalar(tensors[i].shape())) {
|
||||
return errors::Internal("Output tensor ", i, " was not a scalar");
|
||||
}
|
||||
if (timestamp_int == -1) {
|
||||
mutation.emplace_back(::google::cloud::bigtable::SetCell(
|
||||
column_families[i - 1], columns[i - 1],
|
||||
std::move(tensors[i].scalar<tstring>()())));
|
||||
} else {
|
||||
mutation.emplace_back(::google::cloud::bigtable::SetCell(
|
||||
column_families[i - 1], columns[i - 1], timestamp,
|
||||
std::move(tensors[i].scalar<tstring>()())));
|
||||
}
|
||||
}
|
||||
bulk_mutation->emplace_back(std::move(mutation));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name,
|
||||
T* output) {
|
||||
const Tensor* argument_t;
|
||||
TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
|
||||
if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
|
||||
return errors::InvalidArgument(argument_name, " must be a scalar");
|
||||
}
|
||||
*output = argument_t->scalar<T>()();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU),
|
||||
ToBigtableOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -1,79 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
::tensorflow::error::Code GcpErrorCodeToTfErrorCode(
|
||||
::google::cloud::StatusCode code) {
|
||||
switch (code) {
|
||||
case ::google::cloud::StatusCode::kOk:
|
||||
return ::tensorflow::error::OK;
|
||||
case ::google::cloud::StatusCode::kCancelled:
|
||||
return ::tensorflow::error::CANCELLED;
|
||||
case ::google::cloud::StatusCode::kUnknown:
|
||||
return ::tensorflow::error::UNKNOWN;
|
||||
case ::google::cloud::StatusCode::kInvalidArgument:
|
||||
return ::tensorflow::error::INVALID_ARGUMENT;
|
||||
case ::google::cloud::StatusCode::kDeadlineExceeded:
|
||||
return ::tensorflow::error::DEADLINE_EXCEEDED;
|
||||
case ::google::cloud::StatusCode::kNotFound:
|
||||
return ::tensorflow::error::NOT_FOUND;
|
||||
case ::google::cloud::StatusCode::kAlreadyExists:
|
||||
return ::tensorflow::error::ALREADY_EXISTS;
|
||||
case ::google::cloud::StatusCode::kPermissionDenied:
|
||||
return ::tensorflow::error::PERMISSION_DENIED;
|
||||
case ::google::cloud::StatusCode::kUnauthenticated:
|
||||
return ::tensorflow::error::UNAUTHENTICATED;
|
||||
case ::google::cloud::StatusCode::kResourceExhausted:
|
||||
return ::tensorflow::error::RESOURCE_EXHAUSTED;
|
||||
case ::google::cloud::StatusCode::kFailedPrecondition:
|
||||
return ::tensorflow::error::FAILED_PRECONDITION;
|
||||
case ::google::cloud::StatusCode::kAborted:
|
||||
return ::tensorflow::error::ABORTED;
|
||||
case ::google::cloud::StatusCode::kOutOfRange:
|
||||
return ::tensorflow::error::OUT_OF_RANGE;
|
||||
case ::google::cloud::StatusCode::kUnimplemented:
|
||||
return ::tensorflow::error::UNIMPLEMENTED;
|
||||
case ::google::cloud::StatusCode::kInternal:
|
||||
return ::tensorflow::error::INTERNAL;
|
||||
case ::google::cloud::StatusCode::kUnavailable:
|
||||
return ::tensorflow::error::UNAVAILABLE;
|
||||
case ::google::cloud::StatusCode::kDataLoss:
|
||||
return ::tensorflow::error::DATA_LOSS;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status GcpStatusToTfStatus(const ::google::cloud::Status& status) {
|
||||
if (status.ok()) {
|
||||
return Status::OK();
|
||||
}
|
||||
return Status(
|
||||
GcpErrorCodeToTfErrorCode(status.code()),
|
||||
strings::StrCat("Error reading from Cloud Bigtable: ", status.message()));
|
||||
}
|
||||
|
||||
string RegexFromStringSet(const std::vector<tstring>& strs) {
|
||||
CHECK(!strs.empty()) << "The list of strings to turn into a regex was empty.";
|
||||
std::unordered_set<tstring> uniq(strs.begin(), strs.end());
|
||||
if (uniq.size() == 1) {
|
||||
return *uniq.begin();
|
||||
}
|
||||
return absl::StrJoin(uniq, "|");
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -1,153 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_
|
||||
#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_
|
||||
|
||||
#include "google/cloud/bigtable/data_client.h"
|
||||
#include "google/cloud/bigtable/table.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status GcpStatusToTfStatus(const ::google::cloud::Status& status);
|
||||
|
||||
string RegexFromStringSet(const std::vector<tstring>& strs);
|
||||
|
||||
class BigtableClientResource : public ResourceBase {
|
||||
public:
|
||||
BigtableClientResource(
|
||||
string project_id, string instance_id,
|
||||
std::shared_ptr<google::cloud::bigtable::DataClient> client)
|
||||
: project_id_(std::move(project_id)),
|
||||
instance_id_(std::move(instance_id)),
|
||||
client_(std::move(client)) {}
|
||||
|
||||
std::shared_ptr<google::cloud::bigtable::DataClient> get_client() {
|
||||
return client_;
|
||||
}
|
||||
|
||||
string DebugString() const override {
|
||||
return strings::StrCat("BigtableClientResource(project_id: ", project_id_,
|
||||
", instance_id: ", instance_id_, ")");
|
||||
}
|
||||
|
||||
private:
|
||||
const string project_id_;
|
||||
const string instance_id_;
|
||||
std::shared_ptr<google::cloud::bigtable::DataClient> client_;
|
||||
};
|
||||
|
||||
class BigtableTableResource : public ResourceBase {
|
||||
public:
|
||||
BigtableTableResource(BigtableClientResource* client, string table_name)
|
||||
: client_(client),
|
||||
table_name_(std::move(table_name)),
|
||||
table_(client->get_client(), table_name_,
|
||||
google::cloud::bigtable::AlwaysRetryMutationPolicy()) {
|
||||
client_->Ref();
|
||||
}
|
||||
|
||||
~BigtableTableResource() override { client_->Unref(); }
|
||||
|
||||
::google::cloud::bigtable::Table& table() { return table_; }
|
||||
|
||||
string DebugString() const override {
|
||||
return strings::StrCat(
|
||||
"BigtableTableResource(client: ", client_->DebugString(),
|
||||
", table: ", table_name_, ")");
|
||||
}
|
||||
|
||||
private:
|
||||
  BigtableClientResource* client_;  // Owns one ref.
|
||||
const string table_name_;
|
||||
::google::cloud::bigtable::Table table_;
|
||||
};
|
||||
|
||||
namespace data {
|
||||
|
||||
// BigtableReaderDatasetIterator is an abstract class for iterators from
|
||||
// datasets that are "readers" (source datasets, not transformation datasets)
|
||||
// that read from Bigtable.
|
||||
template <typename Dataset>
|
||||
class BigtableReaderDatasetIterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
explicit BigtableReaderDatasetIterator(
|
||||
const typename DatasetIterator<Dataset>::Params& params)
|
||||
: DatasetIterator<Dataset>(params) {}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(EnsureIteratorInitialized());
|
||||
if (iterator_ == reader_->end()) {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
if (!*iterator_) {
|
||||
return GcpStatusToTfStatus(iterator_->status());
|
||||
}
|
||||
*end_of_sequence = false;
|
||||
google::cloud::bigtable::Row& row = **iterator_;
|
||||
Status s = ParseRow(ctx, row, out_tensors);
|
||||
// Ensure we always advance.
|
||||
++iterator_;
|
||||
return s;
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual ::google::cloud::bigtable::RowRange MakeRowRange() = 0;
|
||||
virtual ::google::cloud::bigtable::Filter MakeFilter() = 0;
|
||||
virtual Status ParseRow(IteratorContext* ctx,
|
||||
const ::google::cloud::bigtable::Row& row,
|
||||
std::vector<Tensor>* out_tensors) = 0;
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
return errors::Unimplemented("SaveInternal is currently not supported");
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
return errors::Unimplemented("RestoreInternal is currently not supported");
|
||||
}
|
||||
|
||||
private:
|
||||
Status EnsureIteratorInitialized() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
if (reader_) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
auto rows = MakeRowRange();
|
||||
auto filter = MakeFilter();
|
||||
|
||||
// Note: the this in `this->dataset()` below is necessary due to namespace
|
||||
// name conflicts.
|
||||
reader_.reset(new ::google::cloud::bigtable::RowReader(
|
||||
this->dataset()->table()->table().ReadRows(rows, filter)));
|
||||
iterator_ = reader_->begin();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
mutex mu_;
|
||||
std::unique_ptr<::google::cloud::bigtable::RowReader> reader_ GUARDED_BY(mu_);
|
||||
::google::cloud::bigtable::RowReader::iterator iterator_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_
|
@ -1,248 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
|
||||
public:
|
||||
using UnaryDatasetOpKernel::UnaryDatasetOpKernel;
|
||||
|
||||
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
DatasetBase** output) override {
|
||||
core::RefCountPtr<BigtableTableResource> table;
|
||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table));
|
||||
|
||||
std::vector<tstring> column_families;
|
||||
std::vector<tstring> columns;
|
||||
OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, "column_families",
|
||||
&column_families));
|
||||
OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, "columns", &columns));
|
||||
OP_REQUIRES(
|
||||
ctx, column_families.size() == columns.size(),
|
||||
errors::InvalidArgument("len(columns) != len(column_families)"));
|
||||
|
||||
const uint64 num_outputs = columns.size() + 1;
|
||||
std::vector<PartialTensorShape> output_shapes;
|
||||
output_shapes.reserve(num_outputs);
|
||||
DataTypeVector output_types;
|
||||
output_types.reserve(num_outputs);
|
||||
for (uint64 i = 0; i < num_outputs; ++i) {
|
||||
output_shapes.push_back({});
|
||||
output_types.push_back(DT_STRING);
|
||||
}
|
||||
|
||||
*output =
|
||||
new Dataset(ctx, input, table.get(), std::move(column_families),
|
||||
std::move(columns), output_types, std::move(output_shapes));
|
||||
}
|
||||
|
||||
private:
|
||||
class Dataset : public DatasetBase {
|
||||
public:
|
||||
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
BigtableTableResource* table,
|
||||
std::vector<tstring> column_families,
|
||||
std::vector<tstring> columns,
|
||||
const DataTypeVector& output_types,
|
||||
std::vector<PartialTensorShape> output_shapes)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
input_(input),
|
||||
table_(table),
|
||||
column_families_(std::move(column_families)),
|
||||
columns_(std::move(columns)),
|
||||
output_types_(output_types),
|
||||
output_shapes_(std::move(output_shapes)),
|
||||
filter_(MakeFilter(column_families_, columns_)) {
|
||||
table_->Ref();
|
||||
input_->Ref();
|
||||
}
|
||||
|
||||
~Dataset() override {
|
||||
table_->Unref();
|
||||
input_->Unref();
|
||||
}
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||
const string& prefix) const override {
|
||||
return std::unique_ptr<IteratorBase>(
|
||||
new Iterator({this, strings::StrCat(prefix, "::BigtableLookup")}));
|
||||
}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override {
|
||||
return output_types_;
|
||||
}
|
||||
|
||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||
return output_shapes_;
|
||||
}
|
||||
|
||||
string DebugString() const override {
|
||||
return "BigtableLookupDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return errors::FailedPrecondition(DebugString(),
|
||||
" depends on external state.");
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
return errors::Unimplemented(DebugString(),
|
||||
" does not support serialization");
|
||||
}
|
||||
|
||||
private:
|
||||
static ::google::cloud::bigtable::Filter MakeFilter(
|
||||
const std::vector<tstring>& column_families,
|
||||
const std::vector<tstring>& columns) {
|
||||
string column_family_regex = RegexFromStringSet(column_families);
|
||||
string column_regex = RegexFromStringSet(columns);
|
||||
|
||||
return ::google::cloud::bigtable::Filter::Chain(
|
||||
::google::cloud::bigtable::Filter::Latest(1),
|
||||
::google::cloud::bigtable::Filter::FamilyRegex(column_family_regex),
|
||||
::google::cloud::bigtable::Filter::ColumnRegex(column_regex));
|
||||
}
|
||||
|
||||
class Iterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
explicit Iterator(const Params& params)
|
||||
: DatasetIterator<Dataset>(params) {}
|
||||
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
|
||||
}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_); // Sequence requests.
|
||||
std::vector<Tensor> input_tensors;
|
||||
TF_RETURN_IF_ERROR(
|
||||
input_impl_->GetNext(ctx, &input_tensors, end_of_sequence));
|
||||
if (*end_of_sequence) {
|
||||
return Status::OK();
|
||||
}
|
||||
if (input_tensors.size() != 1) {
|
||||
return errors::InvalidArgument(
|
||||
"Upstream iterator (", dataset()->input_->DebugString(),
|
||||
") did not produce a single `tf.string` `tf.Tensor`. It "
|
||||
"produced ",
|
||||
input_tensors.size(), " tensors.");
|
||||
}
|
||||
if (input_tensors[0].NumElements() == 0) {
|
||||
return errors::InvalidArgument("Upstream iterator (",
|
||||
dataset()->input_->DebugString(),
|
||||
") return an empty set of keys.");
|
||||
}
|
||||
if (input_tensors[0].NumElements() == 1) {
|
||||
// Single key lookup.
|
||||
::google::cloud::StatusOr<
|
||||
std::pair<bool, ::google::cloud::bigtable::Row>>
|
||||
row = dataset()->table_->table().ReadRow(
|
||||
input_tensors[0].scalar<tstring>()(), dataset()->filter_);
|
||||
if (!row.ok()) {
|
||||
return GcpStatusToTfStatus(row.status());
|
||||
}
|
||||
if (!row->first) {
|
||||
return errors::DataLoss("Row key '",
|
||||
input_tensors[0].scalar<tstring>()(),
|
||||
"' not found.");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ParseRow(ctx, row->second, out_tensors));
|
||||
} else {
|
||||
// Batched get.
|
||||
return errors::Unimplemented(
|
||||
"BigtableLookupDataset doesn't yet support batched retrieval.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
return errors::Unimplemented("SaveInternal is currently not supported");
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
return errors::Unimplemented(
|
||||
"RestoreInternal is currently not supported");
|
||||
}
|
||||
|
||||
private:
|
||||
Status ParseRow(IteratorContext* ctx,
|
||||
const ::google::cloud::bigtable::Row& row,
|
||||
std::vector<Tensor>* out_tensors) {
|
||||
out_tensors->reserve(dataset()->columns_.size() + 1);
|
||||
Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {});
|
||||
row_key_tensor.scalar<tstring>()() = tstring(row.row_key());
|
||||
out_tensors->emplace_back(std::move(row_key_tensor));
|
||||
|
||||
if (row.cells().size() > 2 * dataset()->columns_.size()) {
|
||||
LOG(WARNING) << "An excessive number of columns ("
|
||||
<< row.cells().size()
|
||||
<< ") were retrieved when reading row: "
|
||||
<< row.row_key();
|
||||
}
|
||||
|
||||
for (uint64 i = 0; i < dataset()->columns_.size(); ++i) {
|
||||
Tensor col_tensor(ctx->allocator({}), DT_STRING, {});
|
||||
bool found_column = false;
|
||||
for (auto cell_itr = row.cells().begin();
|
||||
!found_column && cell_itr != row.cells().end(); ++cell_itr) {
|
||||
if (cell_itr->family_name() == dataset()->column_families_[i] &&
|
||||
tstring(cell_itr->column_qualifier()) ==
|
||||
dataset()->columns_[i]) {
|
||||
col_tensor.scalar<tstring>()() = tstring(cell_itr->value());
|
||||
found_column = true;
|
||||
}
|
||||
}
|
||||
if (!found_column) {
|
||||
return errors::DataLoss("Column ", dataset()->column_families_[i],
|
||||
":", dataset()->columns_[i],
|
||||
" not found in row: ", row.row_key());
|
||||
}
|
||||
out_tensors->emplace_back(std::move(col_tensor));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
mutex mu_;
|
||||
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
const DatasetBase* const input_;
|
||||
BigtableTableResource* table_;
|
||||
const std::vector<tstring> column_families_;
|
||||
const std::vector<tstring> columns_;
|
||||
const DataTypeVector output_types_;
|
||||
const std::vector<PartialTensorShape> output_shapes_;
|
||||
const ::google::cloud::bigtable::Filter filter_;
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("BigtableLookupDataset").Device(DEVICE_CPU),
|
||||
BigtableLookupDatasetOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -1,121 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
|
||||
public:
|
||||
using DatasetOpKernel::DatasetOpKernel;
|
||||
|
||||
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
|
||||
tstring prefix;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "prefix", &prefix));
|
||||
|
||||
core::RefCountPtr<BigtableTableResource> resource;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
||||
*output = new Dataset(ctx, resource.get(), std::move(prefix));
|
||||
}
|
||||
|
||||
private:
|
||||
class Dataset : public DatasetBase {
|
||||
public:
|
||||
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
|
||||
string prefix)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
table_(table),
|
||||
prefix_(std::move(prefix)) {
|
||||
table_->Ref();
|
||||
}
|
||||
|
||||
~Dataset() override { table_->Unref(); }
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||
const string& prefix) const override {
|
||||
return std::unique_ptr<IteratorBase>(
|
||||
new Iterator({this, strings::StrCat(prefix, "::BigtablePrefixKey")}));
|
||||
}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override {
|
||||
static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
|
||||
return *dtypes;
|
||||
}
|
||||
|
||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||
static std::vector<PartialTensorShape>* shapes =
|
||||
new std::vector<PartialTensorShape>({{}});
|
||||
return *shapes;
|
||||
}
|
||||
|
||||
string DebugString() const override {
|
||||
return "BigtablePrefixKeyDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
BigtableTableResource* table() const { return table_; }
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return errors::FailedPrecondition(DebugString(),
|
||||
" depends on external state.");
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
return errors::Unimplemented(DebugString(),
|
||||
" does not support serialization");
|
||||
}
|
||||
|
||||
private:
|
||||
class Iterator : public BigtableReaderDatasetIterator<Dataset> {
|
||||
public:
|
||||
explicit Iterator(const Params& params)
|
||||
: BigtableReaderDatasetIterator<Dataset>(params) {}
|
||||
|
||||
::google::cloud::bigtable::RowRange MakeRowRange() override {
|
||||
return ::google::cloud::bigtable::RowRange::Prefix(dataset()->prefix_);
|
||||
}
|
||||
::google::cloud::bigtable::Filter MakeFilter() override {
|
||||
return ::google::cloud::bigtable::Filter::Chain(
|
||||
::google::cloud::bigtable::Filter::CellsRowLimit(1),
|
||||
::google::cloud::bigtable::Filter::StripValueTransformer());
|
||||
}
|
||||
Status ParseRow(IteratorContext* ctx,
|
||||
const ::google::cloud::bigtable::Row& row,
|
||||
std::vector<Tensor>* out_tensors) override {
|
||||
Tensor output_tensor(ctx->allocator({}), DT_STRING, {});
|
||||
output_tensor.scalar<tstring>()() = tstring(row.row_key());
|
||||
out_tensors->emplace_back(std::move(output_tensor));
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
BigtableTableResource* const table_;
|
||||
const string prefix_;
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("BigtablePrefixKeyDataset").Device(DEVICE_CPU),
|
||||
BigtablePrefixKeyDatasetOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -1,68 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
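// Returns the first key that sorts after every key sharing `prefix`, i.e. the
// exclusive upper bound of the prefix range. An empty result means the range
// is unbounded above (e.g. the prefix is empty or consists only of 0xFF bytes).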
string MakePrefixEndKey(const string& prefix) {
|
||||
string end = prefix;
|
||||
while (true) {
|
||||
if (end.empty()) {
|
||||
return end;
|
||||
}
|
||||
++end[end.size() - 1];
|
||||
if (end[end.size() - 1] == 0) {
|
||||
// Handle wraparound case.
|
||||
end = end.substr(0, end.size() - 1);
|
||||
} else {
|
||||
return end;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/* static */ MultiModeKeyRange MultiModeKeyRange::FromPrefix(string prefix) {
|
||||
string end = MakePrefixEndKey(prefix);
|
||||
VLOG(1) << "Creating MultiModeKeyRange from Prefix: " << prefix
|
||||
<< ", with end key: " << end;
|
||||
return MultiModeKeyRange(std::move(prefix), std::move(end));
|
||||
}
|
||||
|
||||
/* static */ MultiModeKeyRange MultiModeKeyRange::FromRange(string begin,
|
||||
string end) {
|
||||
return MultiModeKeyRange(std::move(begin), std::move(end));
|
||||
}
|
||||
|
||||
const string& MultiModeKeyRange::begin_key() const { return begin_; }
|
||||
|
||||
const string& MultiModeKeyRange::end_key() const { return end_; }
|
||||
|
||||
bool MultiModeKeyRange::contains_key(StringPiece key) const {
|
||||
if (StringPiece(begin_) > key) {
|
||||
return false;
|
||||
}
|
||||
if (StringPiece(end_) <= key && !end_.empty()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -1,68 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_
|
||||
#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Represents a continuous range of keys defined by either a prefix or a range.
|
||||
//
|
||||
// Ranges are represented as "half-open", where the beginning key is included
|
||||
// in the range, and the end_key is the first excluded key after the range.
|
||||
//
|
||||
// The range of keys can be specified either by a key prefix, or by an explicit
|
||||
// begin key and end key. All methods on this class are valid no matter which
|
||||
// way the range was specified.
|
||||
//
|
||||
// Example:
|
||||
// MultiModeKeyRange range = MultiModeKeyRange::FromPrefix("myPrefix");
|
||||
// if (range.contains_key("myPrefixedKey")) {
|
||||
// LOG(INFO) << "range from " << range.begin_key() << " to "
|
||||
// << range.end_key() << "contains \"myPrefixedKey\"";
|
||||
// }
|
||||
// if (!range.contains_key("randomKey")) {
|
||||
// LOG(INFO) << "range does not contain \"randomKey\"";
|
||||
// }
|
||||
// range = MultiModeKeyRange::FromRange("a_start_key", "z_end_key");
|
||||
class MultiModeKeyRange {
|
||||
public:
|
||||
static MultiModeKeyRange FromPrefix(string prefix);
|
||||
static MultiModeKeyRange FromRange(string begin, string end);
|
||||
|
||||
// The first valid key in the range.
|
||||
const string& begin_key() const;
|
||||
// The first invalid key after the valid range.
|
||||
const string& end_key() const;
|
||||
// Returns true if the provided key is a part of the range, false otherwise.
|
||||
bool contains_key(StringPiece key) const;
|
||||
|
||||
private:
|
||||
MultiModeKeyRange(string begin, string end)
|
||||
: begin_(std::move(begin)), end_(std::move(end)) {}
|
||||
|
||||
const string begin_;
|
||||
const string end_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_
|
@ -1,107 +0,0 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(MultiModeKeyRangeTest, SimplePrefix) {
|
||||
MultiModeKeyRange r = MultiModeKeyRange::FromPrefix("prefix");
|
||||
EXPECT_EQ("prefix", r.begin_key());
|
||||
EXPECT_EQ("prefiy", r.end_key());
|
||||
EXPECT_TRUE(r.contains_key("prefixed_key"));
|
||||
EXPECT_FALSE(r.contains_key("not-prefixed-key"));
|
||||
EXPECT_FALSE(r.contains_key("prefi"));
|
||||
EXPECT_FALSE(r.contains_key("prefiy"));
|
||||
EXPECT_FALSE(r.contains_key("early"));
|
||||
EXPECT_FALSE(r.contains_key(""));
|
||||
}
|
||||
|
||||
TEST(MultiModeKeyRangeTest, Range) {
|
||||
MultiModeKeyRange r = MultiModeKeyRange::FromRange("a", "b");
|
||||
EXPECT_EQ("a", r.begin_key());
|
||||
EXPECT_EQ("b", r.end_key());
|
||||
EXPECT_TRUE(r.contains_key("a"));
|
||||
EXPECT_TRUE(r.contains_key("ab"));
|
||||
EXPECT_FALSE(r.contains_key("b"));
|
||||
EXPECT_FALSE(r.contains_key("bc"));
|
||||
EXPECT_FALSE(r.contains_key("A"));
|
||||
EXPECT_FALSE(r.contains_key("B"));
|
||||
EXPECT_FALSE(r.contains_key(""));
|
||||
}
|
||||
|
||||
TEST(MultiModeKeyRangeTest, InvertedRange) {
|
||||
MultiModeKeyRange r = MultiModeKeyRange::FromRange("b", "a");
|
||||
EXPECT_FALSE(r.contains_key("a"));
|
||||
EXPECT_FALSE(r.contains_key("b"));
|
||||
EXPECT_FALSE(r.contains_key(""));
|
||||
}
|
||||
|
||||
TEST(MultiModeKeyRangeTest, EmptyPrefix) {
|
||||
MultiModeKeyRange r = MultiModeKeyRange::FromPrefix("");
|
||||
EXPECT_EQ("", r.begin_key());
|
||||
EXPECT_EQ("", r.end_key());
|
||||
EXPECT_TRUE(r.contains_key(""));
|
||||
EXPECT_TRUE(r.contains_key("a"));
|
||||
EXPECT_TRUE(r.contains_key("z"));
|
||||
EXPECT_TRUE(r.contains_key("A"));
|
||||
EXPECT_TRUE(r.contains_key("ZZZZZZ"));
|
||||
}
|
||||
|
||||
TEST(MultiModeKeyRangeTest, HalfRange) {
|
||||
MultiModeKeyRange r = MultiModeKeyRange::FromRange("start", "");
|
||||
EXPECT_EQ("start", r.begin_key());
|
||||
EXPECT_EQ("", r.end_key());
|
||||
EXPECT_TRUE(r.contains_key("start"));
|
||||
EXPECT_TRUE(r.contains_key("starting"));
|
||||
EXPECT_TRUE(r.contains_key("z-end"));
|
||||
EXPECT_FALSE(r.contains_key(""));
|
||||
EXPECT_FALSE(r.contains_key("early"));
|
||||
}
|
||||
|
||||
TEST(MultiModeKeyRangeTest, PrefixWrapAround) {
|
||||
string prefix = "abc\xff";
|
||||
MultiModeKeyRange r = MultiModeKeyRange::FromPrefix(prefix);
|
||||
EXPECT_EQ(prefix, r.begin_key());
|
||||
EXPECT_EQ("abd", r.end_key());
|
||||
|
||||
EXPECT_TRUE(r.contains_key("abc\xff\x07"));
|
||||
EXPECT_TRUE(r.contains_key("abc\xff\x15"));
|
||||
EXPECT_TRUE(r.contains_key("abc\xff\x61"));
|
||||
EXPECT_TRUE(r.contains_key("abc\xff\xff"));
|
||||
EXPECT_FALSE(r.contains_key("abc\0"));
|
||||
EXPECT_FALSE(r.contains_key("abd"));
|
||||
}
|
||||
|
||||
TEST(MultiModeKeyRangeTest, PrefixSignedWrapAround) {
|
||||
string prefix = "abc\x7f";
|
||||
MultiModeKeyRange r = MultiModeKeyRange::FromPrefix(prefix);
|
||||
EXPECT_EQ(prefix, r.begin_key());
|
||||
EXPECT_EQ("abc\x80", r.end_key());
|
||||
|
||||
EXPECT_TRUE(r.contains_key("abc\x7f\x07"));
|
||||
EXPECT_TRUE(r.contains_key("abc\x7f\x15"));
|
||||
EXPECT_TRUE(r.contains_key("abc\x7f\x61"));
|
||||
EXPECT_TRUE(r.contains_key("abc\x7f\xff"));
|
||||
EXPECT_FALSE(r.contains_key("abc\0"));
|
||||
EXPECT_FALSE(r.contains_key("abc\x01"));
|
||||
EXPECT_FALSE(r.contains_key("abd"));
|
||||
EXPECT_FALSE(r.contains_key("ab\x80"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -1,127 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
|
||||
public:
|
||||
using DatasetOpKernel::DatasetOpKernel;
|
||||
|
||||
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
|
||||
tstring start_key;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ParseScalarArgument<tstring>(ctx, "start_key", &start_key));
|
||||
tstring end_key;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "end_key", &end_key));
|
||||
|
||||
core::RefCountPtr<BigtableTableResource> resource;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
||||
*output = new Dataset(ctx, resource.get(), std::move(start_key),
|
||||
std::move(end_key));
|
||||
}
|
||||
|
||||
private:
|
||||
class Dataset : public DatasetBase {
|
||||
public:
|
||||
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
|
||||
string start_key, string end_key)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
table_(table),
|
||||
start_key_(std::move(start_key)),
|
||||
end_key_(std::move(end_key)) {
|
||||
table_->Ref();
|
||||
}
|
||||
|
||||
~Dataset() override { table_->Unref(); }
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||
const string& prefix) const override {
|
||||
return std::unique_ptr<IteratorBase>(
|
||||
new Iterator({this, strings::StrCat(prefix, "::BigtableRangeKey")}));
|
||||
}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override {
|
||||
static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
|
||||
return *dtypes;
|
||||
}
|
||||
|
||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||
static std::vector<PartialTensorShape>* shapes =
|
||||
new std::vector<PartialTensorShape>({{}});
|
||||
return *shapes;
|
||||
}
|
||||
|
||||
string DebugString() const override {
|
||||
return "BigtableRangeKeyDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
BigtableTableResource* table() const { return table_; }
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return errors::FailedPrecondition(DebugString(),
|
||||
" depends on external state.");
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
return errors::Unimplemented(DebugString(),
|
||||
" does not support serialization");
|
||||
}
|
||||
|
||||
private:
|
||||
class Iterator : public BigtableReaderDatasetIterator<Dataset> {
|
||||
public:
|
||||
explicit Iterator(const Params& params)
|
||||
: BigtableReaderDatasetIterator<Dataset>(params) {}
|
||||
|
||||
::google::cloud::bigtable::RowRange MakeRowRange() override {
|
||||
return ::google::cloud::bigtable::RowRange::Range(dataset()->start_key_,
|
||||
dataset()->end_key_);
|
||||
}
|
||||
::google::cloud::bigtable::Filter MakeFilter() override {
|
||||
return ::google::cloud::bigtable::Filter::Chain(
|
||||
::google::cloud::bigtable::Filter::CellsRowLimit(1),
|
||||
::google::cloud::bigtable::Filter::StripValueTransformer());
|
||||
}
|
||||
Status ParseRow(IteratorContext* ctx,
|
||||
const ::google::cloud::bigtable::Row& row,
|
||||
std::vector<Tensor>* out_tensors) override {
|
||||
Tensor output_tensor(ctx->allocator({}), DT_STRING, {});
|
||||
output_tensor.scalar<tstring>()() = tstring(row.row_key());
|
||||
out_tensors->emplace_back(std::move(output_tensor));
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
BigtableTableResource* const table_;
|
||||
const string start_key_;
|
||||
const string end_key_;
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("BigtableRangeKeyDataset").Device(DEVICE_CPU),
|
||||
BigtableRangeKeyDatasetOp);
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -1,227 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
|
||||
public:
|
||||
using DatasetOpKernel::DatasetOpKernel;
|
||||
|
||||
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
|
||||
tstring prefix;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "prefix", &prefix));
|
||||
|
||||
tstring start_key;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ParseScalarArgument<tstring>(ctx, "start_key", &start_key));
|
||||
tstring end_key;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "end_key", &end_key));
|
||||
|
||||
core::RefCountPtr<BigtableTableResource> resource;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
||||
|
||||
OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
|
||||
errors::InvalidArgument(
|
||||
"Only one of prefix and start_key can be provided"));
|
||||
if (!prefix.empty()) {
|
||||
OP_REQUIRES(ctx, end_key.empty(),
|
||||
errors::InvalidArgument(
|
||||
"If prefix is specified, end_key must be empty."));
|
||||
}
|
||||
|
||||
*output = new Dataset(ctx, resource.get(), std::move(prefix),
|
||||
std::move(start_key), std::move(end_key));
|
||||
}
|
||||
|
||||
private:
|
||||
class Dataset : public DatasetBase {
|
||||
public:
|
||||
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
|
||||
string prefix, string start_key, string end_key)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
table_(table),
|
||||
key_range_(MakeMultiModeKeyRange(
|
||||
std::move(prefix), std::move(start_key), std::move(end_key))) {
|
||||
table_->Ref();
|
||||
}
|
||||
|
||||
~Dataset() override { table_->Unref(); }
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||
const string& prefix) const override {
|
||||
return std::unique_ptr<IteratorBase>(new Iterator(
|
||||
{this, strings::StrCat(prefix, "::BigtableSampleKeyPairs")}));
|
||||
}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override {
|
||||
static DataTypeVector* dtypes =
|
||||
new DataTypeVector({DT_STRING, DT_STRING});
|
||||
return *dtypes;
|
||||
}
|
||||
|
||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||
static std::vector<PartialTensorShape>* shapes =
|
||||
new std::vector<PartialTensorShape>({{}, {}});
|
||||
return *shapes;
|
||||
}
|
||||
|
||||
string DebugString() const override {
|
||||
return "BigtableSampleKeyPairsDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return errors::FailedPrecondition(DebugString(),
|
||||
" depends on external state.");
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
return errors::Unimplemented(DebugString(),
|
||||
" does not support serialization");
|
||||
}
|
||||
|
||||
private:
|
||||
static MultiModeKeyRange MakeMultiModeKeyRange(string prefix,
|
||||
string start_key,
|
||||
string end_key) {
|
||||
if (!start_key.empty()) {
|
||||
return MultiModeKeyRange::FromRange(std::move(start_key),
|
||||
std::move(end_key));
|
||||
}
|
||||
return MultiModeKeyRange::FromPrefix(std::move(prefix));
|
||||
}
|
||||
|
||||
BigtableTableResource& table() const { return *table_; }
|
||||
|
||||
class Iterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
explicit Iterator(const Params& params)
|
||||
: DatasetIterator<Dataset>(params) {}
|
||||
|
||||
// Computes split points (`keys_`) to use when scanning the table.
|
||||
//
|
||||
// Initialize first retrieves the sample keys from the table (`row_keys`),
|
||||
// as these often form good split points within the table. We then iterate
|
||||
// over them, and copy them to `keys_` if they fall within the requested
|
||||
// range to scan (`dataset()->key_range_`). Because the requested range
|
||||
// might start between elements of the sampled keys list, care is taken to
|
||||
// ensure we don't accidentally miss any subsets of the requested range by
|
||||
// including `begin_key()` and `end_key()` as appropriate.
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
::google::cloud::StatusOr<
|
||||
std::vector<::google::cloud::bigtable::RowKeySample>>
|
||||
row_key_samples = dataset()->table().table().SampleRows();
|
||||
if (!row_key_samples.ok()) {
|
||||
return GcpStatusToTfStatus(row_key_samples.status());
|
||||
}
|
||||
|
||||
for (const auto& row_key_sample : *row_key_samples) {
|
||||
string row_key(row_key_sample.row_key);
|
||||
if (dataset()->key_range_.contains_key(row_key)) {
|
||||
// First key: check to see if we need to add the begin_key.
|
||||
if (keys_.empty() && dataset()->key_range_.begin_key() != row_key) {
|
||||
keys_.push_back(dataset()->key_range_.begin_key());
|
||||
}
|
||||
keys_.push_back(std::move(row_key));
|
||||
} else if (!keys_.empty()) {
|
||||
// If !keys_.empty(), then we have found at least one element of
|
||||
// `row_keys` that is within our requested range
|
||||
// (`dataset()->key_range_`). Because `row_keys` is sorted, if we
|
||||
// have found an element that's not within our key range, then we
|
||||
// are after our requested range (ranges are contiguous) and can end
|
||||
// iteration early.
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the case where we skip over the selected range entirely.
|
||||
if (keys_.empty()) {
|
||||
keys_.push_back(dataset()->key_range_.begin_key());
|
||||
}
|
||||
|
||||
// Last key: check to see if we need to add the end_key.
|
||||
if (keys_.back() != dataset()->key_range_.end_key()) {
|
||||
keys_.push_back(dataset()->key_range_.end_key());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_);
|
||||
if (index_ + 2 > keys_.size()) {
|
||||
*end_of_sequence = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
*end_of_sequence = false;
|
||||
out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
|
||||
TensorShape({}));
|
||||
out_tensors->back().scalar<tstring>()() = keys_[index_];
|
||||
|
||||
out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
|
||||
TensorShape({}));
|
||||
out_tensors->back().scalar<tstring>()() = keys_[index_ + 1];
|
||||
++index_;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
return errors::Unimplemented("SaveInternal is currently not supported");
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
return errors::Unimplemented(
|
||||
"RestoreInternal is currently not supported");
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
size_t index_ GUARDED_BY(mu_) = 0;
|
||||
// Note: we store the keys_ on the iterator instead of the dataset
|
||||
// because we want to re-sample the row keys in case there have been
|
||||
// tablet rebalancing operations since the dataset was created.
|
||||
//
|
||||
// Note: keys_ is readonly after Initialize, and thus does not need a
|
||||
// guarding lock.
|
||||
std::vector<string> keys_;
|
||||
};
|
||||
|
||||
BigtableTableResource* const table_;
|
||||
const MultiModeKeyRange key_range_;
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("BigtableSampleKeyPairsDataset").Device(DEVICE_CPU),
|
||||
BigtableSampleKeyPairsDatasetOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -1,141 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
|
||||
public:
|
||||
using DatasetOpKernel::DatasetOpKernel;
|
||||
|
||||
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
|
||||
core::RefCountPtr<BigtableTableResource> resource;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
||||
*output = new Dataset(ctx, resource.get());
|
||||
}
|
||||
|
||||
private:
|
||||
class Dataset : public DatasetBase {
|
||||
public:
|
||||
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table)
|
||||
: DatasetBase(DatasetContext(ctx)), table_(table) {
|
||||
table_->Ref();
|
||||
}
|
||||
|
||||
~Dataset() override { table_->Unref(); }
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||
const string& prefix) const override {
|
||||
return std::unique_ptr<IteratorBase>(new Iterator(
|
||||
{this, strings::StrCat(prefix, "::BigtableSampleKeys")}));
|
||||
}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override {
|
||||
static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
|
||||
return *dtypes;
|
||||
}
|
||||
|
||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||
static std::vector<PartialTensorShape>* shapes =
|
||||
new std::vector<PartialTensorShape>({{}});
|
||||
return *shapes;
|
||||
}
|
||||
|
||||
string DebugString() const override {
|
||||
return "BigtableRangeKeyDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
BigtableTableResource* table() const { return table_; }
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return errors::FailedPrecondition(DebugString(),
|
||||
" depends on external state.");
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
return errors::Unimplemented(DebugString(),
|
||||
" does not support serialization");
|
||||
}
|
||||
|
||||
private:
|
||||
class Iterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
explicit Iterator(const Params& params)
|
||||
: DatasetIterator<Dataset>(params) {}
|
||||
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
::google::cloud::StatusOr<
|
||||
std::vector<::google::cloud::bigtable::RowKeySample>>
|
||||
sampled_rows = dataset()->table()->table().SampleRows();
|
||||
if (!sampled_rows.ok()) {
|
||||
row_keys_.clear();
|
||||
return GcpStatusToTfStatus(sampled_rows.status());
|
||||
}
|
||||
row_keys_ = std::move(*sampled_rows);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_);
|
||||
if (index_ < row_keys_.size()) {
|
||||
out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
|
||||
TensorShape({}));
|
||||
out_tensors->back().scalar<tstring>()() =
|
||||
tstring(row_keys_[index_].row_key);
|
||||
*end_of_sequence = false;
|
||||
index_++;
|
||||
} else {
|
||||
*end_of_sequence = true;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
return errors::Unimplemented("SaveInternal is currently not supported");
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
return errors::Unimplemented(
|
||||
"RestoreInternal is currently not supported");
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
size_t index_ = 0;
|
||||
std::vector<::google::cloud::bigtable::RowKeySample> row_keys_;
|
||||
};
|
||||
|
||||
BigtableTableResource* const table_;
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("BigtableSampleKeysDataset").Device(DEVICE_CPU),
|
||||
BigtableSampleKeysDatasetOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -1,235 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
class BigtableScanDatasetOp : public DatasetOpKernel {
|
||||
public:
|
||||
using DatasetOpKernel::DatasetOpKernel;
|
||||
|
||||
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
|
||||
tstring prefix;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "prefix", &prefix));
|
||||
tstring start_key;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ParseScalarArgument<tstring>(ctx, "start_key", &start_key));
|
||||
tstring end_key;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "end_key", &end_key));
|
||||
|
||||
OP_REQUIRES(ctx, !(prefix.empty() && start_key.empty()),
|
||||
errors::InvalidArgument(
|
||||
"Either prefix or start_key must be specified"));
|
||||
OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
|
||||
errors::InvalidArgument(
|
||||
"Only one of prefix and start_key can be provided"));
|
||||
if (!prefix.empty()) {
|
||||
OP_REQUIRES(ctx, end_key.empty(),
|
||||
errors::InvalidArgument(
|
||||
"If prefix is specified, end_key must be empty."));
|
||||
}
|
||||
|
||||
std::vector<tstring> column_families;
|
||||
std::vector<tstring> columns;
|
||||
OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, "column_families",
|
||||
&column_families));
|
||||
OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, "columns", &columns));
|
||||
OP_REQUIRES(
|
||||
ctx, column_families.size() == columns.size(),
|
||||
errors::InvalidArgument("len(columns) != len(column_families)"));
|
||||
OP_REQUIRES(ctx, !column_families.empty(),
|
||||
errors::InvalidArgument("`column_families` is empty"));
|
||||
|
||||
float probability = 0;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ParseScalarArgument<float>(ctx, "probability", &probability));
|
||||
OP_REQUIRES(
|
||||
ctx, probability > 0 && probability <= 1,
|
||||
errors::InvalidArgument(
|
||||
"Probability outside the range of (0, 1]. Got: ", probability));
|
||||
|
||||
core::RefCountPtr<BigtableTableResource> resource;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
|
||||
|
||||
const uint64 num_outputs = columns.size() + 1;
|
||||
std::vector<PartialTensorShape> output_shapes;
|
||||
output_shapes.reserve(num_outputs);
|
||||
DataTypeVector output_types;
|
||||
output_types.reserve(num_outputs);
|
||||
for (uint64 i = 0; i < num_outputs; ++i) {
|
||||
output_shapes.push_back({});
|
||||
output_types.push_back(DT_STRING);
|
||||
}
|
||||
|
||||
*output = new Dataset(ctx, resource.get(), std::move(prefix),
|
||||
std::move(start_key), std::move(end_key),
|
||||
std::move(column_families), std::move(columns),
|
||||
probability, output_types, std::move(output_shapes));
|
||||
}
|
||||
|
||||
private:
|
||||
class Dataset : public DatasetBase {
|
||||
public:
|
||||
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
|
||||
string prefix, string start_key, string end_key,
|
||||
std::vector<tstring> column_families,
|
||||
std::vector<tstring> columns, float probability,
|
||||
const DataTypeVector& output_types,
|
||||
std::vector<PartialTensorShape> output_shapes)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
table_(table),
|
||||
prefix_(std::move(prefix)),
|
||||
start_key_(std::move(start_key)),
|
||||
end_key_(std::move(end_key)),
|
||||
column_families_(std::move(column_families)),
|
||||
columns_(std::move(columns)),
|
||||
column_family_regex_(RegexFromStringSet(column_families_)),
|
||||
column_regex_(RegexFromStringSet(columns_)),
|
||||
probability_(probability),
|
||||
output_types_(output_types),
|
||||
output_shapes_(std::move(output_shapes)) {
|
||||
table_->Ref();
|
||||
}
|
||||
|
||||
~Dataset() override { table_->Unref(); }
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||
const string& prefix) const override {
|
||||
return std::unique_ptr<IteratorBase>(
|
||||
new Iterator({this, strings::StrCat(prefix, "::BigtableScan")}));
|
||||
}
|
||||
|
||||
const DataTypeVector& output_dtypes() const override {
|
||||
return output_types_;
|
||||
}
|
||||
|
||||
const std::vector<PartialTensorShape>& output_shapes() const override {
|
||||
return output_shapes_;
|
||||
}
|
||||
|
||||
string DebugString() const override {
|
||||
return "BigtableScanDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
BigtableTableResource* table() const { return table_; }
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return errors::FailedPrecondition(DebugString(),
|
||||
" depends on external state.");
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
return errors::Unimplemented(DebugString(),
|
||||
" does not support serialization");
|
||||
}
|
||||
|
||||
private:
|
||||
class Iterator : public BigtableReaderDatasetIterator<Dataset> {
|
||||
public:
|
||||
explicit Iterator(const Params& params)
|
||||
: BigtableReaderDatasetIterator<Dataset>(params) {}
|
||||
|
||||
::google::cloud::bigtable::RowRange MakeRowRange() override {
|
||||
if (!dataset()->prefix_.empty()) {
|
||||
DCHECK(dataset()->start_key_.empty());
|
||||
return ::google::cloud::bigtable::RowRange::Prefix(
|
||||
dataset()->prefix_);
|
||||
} else {
|
||||
DCHECK(!dataset()->start_key_.empty())
|
||||
<< "Both prefix and start_key were empty!";
|
||||
return ::google::cloud::bigtable::RowRange::Range(
|
||||
dataset()->start_key_, dataset()->end_key_);
|
||||
}
|
||||
}
|
||||
::google::cloud::bigtable::Filter MakeFilter() override {
|
||||
// TODO(saeta): Investigate optimal ordering here.
|
||||
return ::google::cloud::bigtable::Filter::Chain(
|
||||
::google::cloud::bigtable::Filter::Latest(1),
|
||||
::google::cloud::bigtable::Filter::FamilyRegex(
|
||||
dataset()->column_family_regex_),
|
||||
::google::cloud::bigtable::Filter::ColumnRegex(
|
||||
dataset()->column_regex_),
|
||||
dataset()->probability_ != 1.0
|
||||
? ::google::cloud::bigtable::Filter::RowSample(
|
||||
dataset()->probability_)
|
||||
: ::google::cloud::bigtable::Filter::PassAllFilter());
|
||||
}
|
||||
Status ParseRow(IteratorContext* ctx,
|
||||
const ::google::cloud::bigtable::Row& row,
|
||||
std::vector<Tensor>* out_tensors) override {
|
||||
out_tensors->reserve(dataset()->columns_.size() + 1);
|
||||
Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {});
|
||||
row_key_tensor.scalar<tstring>()() = tstring(row.row_key());
|
||||
out_tensors->emplace_back(std::move(row_key_tensor));
|
||||
|
||||
if (row.cells().size() > 2 * dataset()->columns_.size()) {
|
||||
LOG(WARNING) << "An excessive number of columns ("
|
||||
<< row.cells().size()
|
||||
<< ") were retrieved when reading row: "
|
||||
<< row.row_key();
|
||||
}
|
||||
|
||||
for (uint64 i = 0; i < dataset()->columns_.size(); ++i) {
|
||||
Tensor col_tensor(ctx->allocator({}), DT_STRING, {});
|
||||
bool found_column = false;
|
||||
for (auto cell_itr = row.cells().begin();
|
||||
!found_column && cell_itr != row.cells().end(); ++cell_itr) {
|
||||
if (cell_itr->family_name() == dataset()->column_families_[i] &&
|
||||
tstring(cell_itr->column_qualifier()) ==
|
||||
dataset()->columns_[i]) {
|
||||
col_tensor.scalar<tstring>()() = tstring(cell_itr->value());
|
||||
found_column = true;
|
||||
}
|
||||
}
|
||||
if (!found_column) {
|
||||
return errors::InvalidArgument(
|
||||
"Column ", dataset()->column_families_[i], ":",
|
||||
dataset()->columns_[i], " not found in row: ", row.row_key());
|
||||
}
|
||||
out_tensors->emplace_back(std::move(col_tensor));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
BigtableTableResource* table_;
|
||||
const string prefix_;
|
||||
const string start_key_;
|
||||
const string end_key_;
|
||||
const std::vector<tstring> column_families_;
|
||||
const std::vector<tstring> columns_;
|
||||
const string column_family_regex_;
|
||||
const string column_regex_;
|
||||
const float probability_;
|
||||
const DataTypeVector output_types_;
|
||||
const std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("BigtableScanDataset").Device(DEVICE_CPU),
|
||||
BigtableScanDatasetOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -1,462 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h"
|
||||
|
||||
#include "external/com_github_googleapis_googleapis/google/bigtable/v2/data.pb.h"
|
||||
#include "google/protobuf/wrappers.pb.h"
|
||||
#include "re2/re2.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
// #include "util/task/codes.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
void UpdateRow(const ::google::bigtable::v2::Mutation& mut,
|
||||
std::map<string, string>* row) {
|
||||
if (mut.has_set_cell()) {
|
||||
CHECK(mut.set_cell().timestamp_micros() >= -1)
|
||||
<< "Timestamp_micros: " << mut.set_cell().timestamp_micros();
|
||||
auto col =
|
||||
strings::Printf("%s:%s", mut.set_cell().family_name().c_str(),
|
||||
string(mut.set_cell().column_qualifier()).c_str());
|
||||
(*row)[col] = string(mut.set_cell().value());
|
||||
} else if (mut.has_delete_from_column()) {
|
||||
auto col = strings::Printf(
|
||||
"%s:%s", mut.delete_from_column().family_name().c_str(),
|
||||
string(mut.delete_from_column().column_qualifier()).c_str());
|
||||
row->erase(col);
|
||||
} else if (mut.has_delete_from_family()) {
|
||||
auto itr = row->lower_bound(mut.delete_from_family().family_name());
|
||||
auto prefix =
|
||||
strings::Printf("%s:", mut.delete_from_family().family_name().c_str());
|
||||
while (itr != row->end() && itr->first.substr(0, prefix.size()) == prefix) {
|
||||
row->erase(itr);
|
||||
}
|
||||
} else if (mut.has_delete_from_row()) {
|
||||
row->clear();
|
||||
} else {
|
||||
LOG(ERROR) << "Unknown mutation: " << mut.ShortDebugString();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class SampleRowKeysResponse : public grpc::ClientReaderInterface<
|
||||
google::bigtable::v2::SampleRowKeysResponse> {
|
||||
public:
|
||||
explicit SampleRowKeysResponse(BigtableTestClient* client)
|
||||
: client_(client) {}
|
||||
|
||||
bool NextMessageSize(uint32_t* sz) override {
|
||||
mutex_lock l(mu_);
|
||||
mutex_lock l2(client_->mu_);
|
||||
if (num_messages_sent_ * 2 < client_->table_.rows.size()) {
|
||||
*sz = 10000; // A sufficiently high enough value to not worry about.
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Read(google::bigtable::v2::SampleRowKeysResponse* resp) override {
|
||||
// Send every other key from the table.
|
||||
mutex_lock l(mu_);
|
||||
mutex_lock l2(client_->mu_);
|
||||
*resp = google::bigtable::v2::SampleRowKeysResponse();
|
||||
auto itr = client_->table_.rows.begin();
|
||||
for (uint64 i = 0; i < 2 * num_messages_sent_; ++i) {
|
||||
++itr;
|
||||
if (itr == client_->table_.rows.end()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
resp->set_row_key(itr->first);
|
||||
resp->set_offset_bytes(100 * num_messages_sent_);
|
||||
num_messages_sent_++;
|
||||
return true;
|
||||
}
|
||||
|
||||
grpc::Status Finish() override { return grpc::Status::OK; }
|
||||
|
||||
void WaitForInitialMetadata() override {} // Do nothing.
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
int64 num_messages_sent_ GUARDED_BY(mu_) = 0;
|
||||
BigtableTestClient* client_; // Not owned.
|
||||
};
|
||||
|
||||
class ReadRowsResponse : public grpc::ClientReaderInterface<
|
||||
google::bigtable::v2::ReadRowsResponse> {
|
||||
public:
|
||||
ReadRowsResponse(BigtableTestClient* client,
|
||||
google::bigtable::v2::ReadRowsRequest const& request)
|
||||
: client_(client), request_(request) {}
|
||||
|
||||
bool NextMessageSize(uint32_t* sz) override {
|
||||
mutex_lock l(mu_);
|
||||
if (sent_first_message_) {
|
||||
return false;
|
||||
}
|
||||
*sz = 10000000; // A sufficiently high enough value to not worry about.
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Read(google::bigtable::v2::ReadRowsResponse* resp) override {
|
||||
mutex_lock l(mu_);
|
||||
if (sent_first_message_) {
|
||||
return false;
|
||||
}
|
||||
sent_first_message_ = true;
|
||||
RowFilter filter = MakeRowFilter();
|
||||
|
||||
mutex_lock l2(client_->mu_);
|
||||
*resp = google::bigtable::v2::ReadRowsResponse();
|
||||
// Send all contents in first response.
|
||||
for (auto itr = client_->table_.rows.begin();
|
||||
itr != client_->table_.rows.end(); ++itr) {
|
||||
if (filter.AllowRow(itr->first)) {
|
||||
::google::bigtable::v2::ReadRowsResponse_CellChunk* chunk = nullptr;
|
||||
bool sent_first = false;
|
||||
for (auto col_itr = itr->second.columns.begin();
|
||||
col_itr != itr->second.columns.end(); ++col_itr) {
|
||||
if (filter.AllowColumn(col_itr->first)) {
|
||||
chunk = resp->add_chunks();
|
||||
if (!sent_first) {
|
||||
sent_first = true;
|
||||
chunk->set_row_key(itr->first);
|
||||
}
|
||||
auto colon_idx = col_itr->first.find(":");
|
||||
CHECK(colon_idx != string::npos)
|
||||
<< "No ':' found in: " << col_itr->first;
|
||||
chunk->mutable_family_name()->set_value(
|
||||
string(col_itr->first, 0, colon_idx));
|
||||
chunk->mutable_qualifier()->set_value(
|
||||
string(col_itr->first, ++colon_idx));
|
||||
if (!filter.strip_values) {
|
||||
chunk->set_value(col_itr->second);
|
||||
}
|
||||
if (filter.only_one_column) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (sent_first) {
|
||||
// We are sending this row, so set the commit flag on the last chunk.
|
||||
chunk->set_commit_row(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
grpc::Status Finish() override { return grpc::Status::OK; }
|
||||
|
||||
void WaitForInitialMetadata() override {} // Do nothing.
|
||||
|
||||
private:
|
||||
struct RowFilter {
|
||||
std::set<string> row_set;
|
||||
std::vector<std::pair<string, string>> row_ranges;
|
||||
double row_sample = 0.0; // Note: currently ignored.
|
||||
std::unique_ptr<RE2> col_filter;
|
||||
bool strip_values = false;
|
||||
bool only_one_column = false;
|
||||
|
||||
bool AllowRow(const string& row) {
|
||||
if (row_set.find(row) != row_set.end()) {
|
||||
return true;
|
||||
}
|
||||
for (const auto& range : row_ranges) {
|
||||
if (range.first <= row && range.second > row) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool AllowColumn(const string& col) {
|
||||
if (col_filter) {
|
||||
return RE2::FullMatch(col, *col_filter);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
RowFilter MakeRowFilter() {
|
||||
RowFilter filter;
|
||||
for (auto i = request_.rows().row_keys().begin();
|
||||
i != request_.rows().row_keys().end(); ++i) {
|
||||
filter.row_set.insert(string(*i));
|
||||
}
|
||||
for (auto i = request_.rows().row_ranges().begin();
|
||||
i != request_.rows().row_ranges().end(); ++i) {
|
||||
if (i->start_key_case() !=
|
||||
google::bigtable::v2::RowRange::kStartKeyClosed ||
|
||||
i->end_key_case() != google::bigtable::v2::RowRange::kEndKeyOpen) {
|
||||
LOG(WARNING) << "Skipping row range that cannot be processed: "
|
||||
<< i->ShortDebugString();
|
||||
continue;
|
||||
}
|
||||
filter.row_ranges.emplace_back(std::make_pair(
|
||||
string(i->start_key_closed()), string(i->end_key_open())));
|
||||
}
|
||||
if (request_.filter().has_chain()) {
|
||||
string family_filter;
|
||||
string qualifier_filter;
|
||||
for (auto i = request_.filter().chain().filters().begin();
|
||||
i != request_.filter().chain().filters().end(); ++i) {
|
||||
switch (i->filter_case()) {
|
||||
case google::bigtable::v2::RowFilter::kFamilyNameRegexFilter:
|
||||
family_filter = i->family_name_regex_filter();
|
||||
break;
|
||||
case google::bigtable::v2::RowFilter::kColumnQualifierRegexFilter:
|
||||
qualifier_filter = i->column_qualifier_regex_filter();
|
||||
break;
|
||||
case google::bigtable::v2::RowFilter::kCellsPerColumnLimitFilter:
|
||||
if (i->cells_per_column_limit_filter() != 1) {
|
||||
LOG(ERROR) << "Unexpected cells_per_column_limit_filter: "
|
||||
<< i->cells_per_column_limit_filter();
|
||||
}
|
||||
break;
|
||||
case google::bigtable::v2::RowFilter::kStripValueTransformer:
|
||||
filter.strip_values = i->strip_value_transformer();
|
||||
break;
|
||||
case google::bigtable::v2::RowFilter::kRowSampleFilter:
|
||||
LOG(INFO) << "Ignoring row sample directive.";
|
||||
break;
|
||||
case google::bigtable::v2::RowFilter::kPassAllFilter:
|
||||
break;
|
||||
case google::bigtable::v2::RowFilter::kCellsPerRowLimitFilter:
|
||||
filter.only_one_column = true;
|
||||
break;
|
||||
default:
|
||||
LOG(WARNING) << "Ignoring unknown filter type: "
|
||||
<< i->ShortDebugString();
|
||||
}
|
||||
}
|
||||
if (family_filter.empty() || qualifier_filter.empty()) {
|
||||
LOG(WARNING) << "Missing regex!";
|
||||
} else {
|
||||
string regex = strings::Printf("%s:%s", family_filter.c_str(),
|
||||
qualifier_filter.c_str());
|
||||
filter.col_filter.reset(new RE2(regex));
|
||||
}
|
||||
} else {
|
||||
LOG(WARNING) << "Read request did not have a filter chain specified: "
|
||||
<< request_.filter().DebugString();
|
||||
}
|
||||
return filter;
|
||||
}
|
||||
|
||||
mutex mu_;
|
||||
bool sent_first_message_ GUARDED_BY(mu_) = false;
|
||||
BigtableTestClient* client_; // Not owned.
|
||||
const google::bigtable::v2::ReadRowsRequest request_;
|
||||
};
|
||||
|
||||
class MutateRowsResponse : public grpc::ClientReaderInterface<
|
||||
google::bigtable::v2::MutateRowsResponse> {
|
||||
public:
|
||||
explicit MutateRowsResponse(size_t num_successes)
|
||||
: num_successes_(num_successes) {}
|
||||
|
||||
bool NextMessageSize(uint32_t* sz) override {
|
||||
mutex_lock l(mu_);
|
||||
if (sent_first_message_) {
|
||||
return false;
|
||||
}
|
||||
*sz = 10000000; // A sufficiently high enough value to not worry about.
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Read(google::bigtable::v2::MutateRowsResponse* resp) override {
|
||||
mutex_lock l(mu_);
|
||||
if (sent_first_message_) {
|
||||
return false;
|
||||
}
|
||||
sent_first_message_ = true;
|
||||
*resp = google::bigtable::v2::MutateRowsResponse();
|
||||
for (size_t i = 0; i < num_successes_; ++i) {
|
||||
auto entry = resp->add_entries();
|
||||
entry->set_index(i);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
grpc::Status Finish() override { return grpc::Status::OK; }
|
||||
|
||||
void WaitForInitialMetadata() override {} // Do nothing.
|
||||
|
||||
private:
|
||||
const size_t num_successes_;
|
||||
|
||||
mutex mu_;
|
||||
bool sent_first_message_ = false;
|
||||
};
|
||||
|
||||
grpc::Status BigtableTestClient::MutateRow(
|
||||
grpc::ClientContext* context,
|
||||
google::bigtable::v2::MutateRowRequest const& request,
|
||||
google::bigtable::v2::MutateRowResponse* response) {
|
||||
mutex_lock l(mu_);
|
||||
auto* row = &table_.rows[string(request.row_key())];
|
||||
for (int i = 0; i < request.mutations_size(); ++i) {
|
||||
UpdateRow(request.mutations(i), &row->columns);
|
||||
}
|
||||
*response = google::bigtable::v2::MutateRowResponse();
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
grpc::Status BigtableTestClient::CheckAndMutateRow(
|
||||
grpc::ClientContext* context,
|
||||
google::bigtable::v2::CheckAndMutateRowRequest const& request,
|
||||
google::bigtable::v2::CheckAndMutateRowResponse* response) {
|
||||
return grpc::Status(grpc::StatusCode::UNIMPLEMENTED,
|
||||
"CheckAndMutateRow not implemented.");
|
||||
}
|
||||
grpc::Status BigtableTestClient::ReadModifyWriteRow(
|
||||
grpc::ClientContext* context,
|
||||
google::bigtable::v2::ReadModifyWriteRowRequest const& request,
|
||||
google::bigtable::v2::ReadModifyWriteRowResponse* response) {
|
||||
return grpc::Status(grpc::StatusCode::UNIMPLEMENTED,
|
||||
"ReadModifyWriteRow not implemented.");
|
||||
}
|
||||
std::unique_ptr<grpc::ClientAsyncResponseReaderInterface<
|
||||
google::bigtable::v2::ReadModifyWriteRowResponse>>
|
||||
BigtableTestClient::AsyncReadModifyWriteRow(
|
||||
grpc::ClientContext* context,
|
||||
google::bigtable::v2::ReadModifyWriteRowRequest const& request,
|
||||
grpc::CompletionQueue* cq) {
|
||||
LOG(WARNING) << "Call to AsyncReadModifyWriteRow:" << __func__
|
||||
<< "(); this will likely cause a crash!";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<
|
||||
grpc::ClientReaderInterface<google::bigtable::v2::ReadRowsResponse>>
|
||||
BigtableTestClient::ReadRows(
|
||||
grpc::ClientContext* context,
|
||||
google::bigtable::v2::ReadRowsRequest const& request) {
|
||||
return MakeUnique<ReadRowsResponse>(this, request);
|
||||
}
|
||||
|
||||
std::unique_ptr<
|
||||
grpc::ClientReaderInterface<google::bigtable::v2::SampleRowKeysResponse>>
|
||||
BigtableTestClient::SampleRowKeys(
|
||||
grpc::ClientContext* context,
|
||||
google::bigtable::v2::SampleRowKeysRequest const& request) {
|
||||
return MakeUnique<SampleRowKeysResponse>(this);
|
||||
}
|
||||
std::unique_ptr<
|
||||
grpc::ClientReaderInterface<google::bigtable::v2::MutateRowsResponse>>
|
||||
BigtableTestClient::MutateRows(
|
||||
grpc::ClientContext* context,
|
||||
google::bigtable::v2::MutateRowsRequest const& request) {
|
||||
mutex_lock l(mu_);
|
||||
for (auto i = request.entries().begin(); i != request.entries().end(); ++i) {
|
||||
auto* row = &table_.rows[string(i->row_key())];
|
||||
for (auto mut = i->mutations().begin(); mut != i->mutations().end();
|
||||
++mut) {
|
||||
UpdateRow(*mut, &row->columns);
|
||||
}
|
||||
}
|
||||
return MakeUnique<MutateRowsResponse>(request.entries_size());
|
||||
}
|
||||
|
||||
std::unique_ptr<grpc::ClientAsyncResponseReaderInterface<
|
||||
google::bigtable::v2::MutateRowResponse>>
|
||||
BigtableTestClient::AsyncMutateRow(
|
||||
grpc::ClientContext* context,
|
||||
google::bigtable::v2::MutateRowRequest const& request,
|
||||
grpc::CompletionQueue* cq) {
|
||||
LOG(WARNING) << "Call to InMemoryDataClient::" << __func__
|
||||
<< "(); this will likely cause a crash!";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<::grpc::ClientAsyncReaderInterface<
|
||||
::google::bigtable::v2::SampleRowKeysResponse>>
|
||||
BigtableTestClient::AsyncSampleRowKeys(
|
||||
::grpc::ClientContext* context,
|
||||
const ::google::bigtable::v2::SampleRowKeysRequest& request,
|
||||
::grpc::CompletionQueue* cq, void* tag) {
|
||||
LOG(WARNING) << "Call to InMemoryDataClient::" << __func__
|
||||
<< "(); this will likely cause a crash!";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<::grpc::ClientAsyncReaderInterface<
|
||||
::google::bigtable::v2::MutateRowsResponse>>
|
||||
BigtableTestClient::AsyncMutateRows(
|
||||
::grpc::ClientContext* context,
|
||||
const ::google::bigtable::v2::MutateRowsRequest& request,
|
||||
::grpc::CompletionQueue* cq, void* tag) {
|
||||
LOG(WARNING) << "Call to InMemoryDataClient::" << __func__
|
||||
<< "(); this will likely cause a crash!";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<grpc::ClientAsyncResponseReaderInterface<
|
||||
google::bigtable::v2::CheckAndMutateRowResponse>>
|
||||
BigtableTestClient::AsyncCheckAndMutateRow(
|
||||
grpc::ClientContext* context,
|
||||
const google::bigtable::v2::CheckAndMutateRowRequest& request,
|
||||
grpc::CompletionQueue* cq) {
|
||||
LOG(WARNING) << "Call to InMemoryDataClient::" << __func__
|
||||
<< "(); this will likely cause a crash!";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<
|
||||
grpc::ClientAsyncReaderInterface<google::bigtable::v2::ReadRowsResponse>>
|
||||
BigtableTestClient::AsyncReadRows(
|
||||
grpc::ClientContext* context,
|
||||
const google::bigtable::v2::ReadRowsRequest& request,
|
||||
grpc::CompletionQueue* cq, void* tag) {
|
||||
LOG(WARNING) << "Call to InMemoryDataClient::" << __func__
|
||||
<< "(); this will likely cause a crash!";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<
|
||||
grpc::ClientAsyncReaderInterface<google::bigtable::v2::MutateRowsResponse>>
|
||||
BigtableTestClient::PrepareAsyncMutateRows(
|
||||
grpc::ClientContext* context,
|
||||
const google::bigtable::v2::MutateRowsRequest& request,
|
||||
grpc::CompletionQueue* cq) {
|
||||
LOG(WARNING) << "Call to InMemoryDataClient::" << __func__
|
||||
<< "(); this will likely cause a crash!";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<::grpc::ClientAsyncReaderInterface<
|
||||
::google::bigtable::v2::ReadRowsResponse>>
|
||||
BigtableTestClient::PrepareAsyncReadRows(
|
||||
::grpc::ClientContext* context,
|
||||
const ::google::bigtable::v2::ReadRowsRequest& request,
|
||||
::grpc::CompletionQueue* cq) {
|
||||
LOG(WARNING) << "Call to InMemoryDataClient::" << __func__
|
||||
<< "(); this will likely cause a crash!";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<grpc::Channel> BigtableTestClient::Channel() {
|
||||
LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely "
|
||||
"cause a crash!";
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace tensorflow
|
@ -1,138 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_
|
||||
#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_
|
||||
|
||||
#include "google/cloud/bigtable/data_client.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class BigtableTestClient : public ::google::cloud::bigtable::DataClient {
|
||||
public:
|
||||
std::string const& project_id() const override { return project_id_; }
|
||||
std::string const& instance_id() const override { return instance_id_; }
|
||||
void reset() override {
|
||||
mutex_lock l(mu_);
|
||||
table_ = Table();
|
||||
}
|
||||
|
||||
grpc::Status MutateRow(
|
||||
grpc::ClientContext* context,
|
||||
google::bigtable::v2::MutateRowRequest const& request,
|
||||
google::bigtable::v2::MutateRowResponse* response) override;
|
||||
|
||||
grpc::Status CheckAndMutateRow(
|
||||
grpc::ClientContext* context,
|
||||
google::bigtable::v2::CheckAndMutateRowRequest const& request,
|
||||
google::bigtable::v2::CheckAndMutateRowResponse* response) override;
|
||||
|
||||
grpc::Status ReadModifyWriteRow(
|
||||
grpc::ClientContext* context,
|
||||
google::bigtable::v2::ReadModifyWriteRowRequest const& request,
|
||||
google::bigtable::v2::ReadModifyWriteRowResponse* response) override;
|
||||
|
||||
std::unique_ptr<grpc::ClientAsyncResponseReaderInterface<
|
||||
google::bigtable::v2::ReadModifyWriteRowResponse>>
|
||||
AsyncReadModifyWriteRow(
|
||||
grpc::ClientContext* context,
|
||||
google::bigtable::v2::ReadModifyWriteRowRequest const& request,
|
||||
grpc::CompletionQueue* cq) override;
|
||||
|
||||
std::unique_ptr<
|
||||
grpc::ClientReaderInterface<google::bigtable::v2::ReadRowsResponse>>
|
||||
ReadRows(grpc::ClientContext* context,
|
||||
google::bigtable::v2::ReadRowsRequest const& request) override;
|
||||
std::unique_ptr<
|
||||
grpc::ClientReaderInterface<google::bigtable::v2::SampleRowKeysResponse>>
|
||||
SampleRowKeys(
|
||||
grpc::ClientContext* context,
|
||||
google::bigtable::v2::SampleRowKeysRequest const& request) override;
|
||||
|
||||
std::unique_ptr<
|
||||
grpc::ClientReaderInterface<google::bigtable::v2::MutateRowsResponse>>
|
||||
MutateRows(grpc::ClientContext* context,
|
||||
google::bigtable::v2::MutateRowsRequest const& request) override;
|
||||
|
||||
std::unique_ptr<grpc::ClientAsyncResponseReaderInterface<
|
||||
google::bigtable::v2::MutateRowResponse>>
|
||||
AsyncMutateRow(grpc::ClientContext* context,
|
||||
google::bigtable::v2::MutateRowRequest const& request,
|
||||
grpc::CompletionQueue* cq) override;
|
||||
|
||||
std::unique_ptr<::grpc::ClientAsyncReaderInterface<
|
||||
::google::bigtable::v2::SampleRowKeysResponse>>
|
||||
AsyncSampleRowKeys(
|
||||
::grpc::ClientContext* context,
|
||||
const ::google::bigtable::v2::SampleRowKeysRequest& request,
|
||||
::grpc::CompletionQueue* cq, void* tag) override;
|
||||
|
||||
std::unique_ptr<::grpc::ClientAsyncReaderInterface<
|
||||
::google::bigtable::v2::MutateRowsResponse>>
|
||||
AsyncMutateRows(::grpc::ClientContext* context,
|
||||
const ::google::bigtable::v2::MutateRowsRequest& request,
|
||||
::grpc::CompletionQueue* cq, void* tag) override;
|
||||
|
||||
std::unique_ptr<grpc::ClientAsyncResponseReaderInterface<
|
||||
google::bigtable::v2::CheckAndMutateRowResponse>>
|
||||
AsyncCheckAndMutateRow(
|
||||
grpc::ClientContext* context,
|
||||
const google::bigtable::v2::CheckAndMutateRowRequest& request,
|
||||
grpc::CompletionQueue* cq) override;
|
||||
|
||||
std::unique_ptr<
|
||||
grpc::ClientAsyncReaderInterface<google::bigtable::v2::ReadRowsResponse>>
|
||||
AsyncReadRows(grpc::ClientContext* context,
|
||||
const google::bigtable::v2::ReadRowsRequest& request,
|
||||
grpc::CompletionQueue* cq, void* tag) override;
|
||||
|
||||
std::unique_ptr<grpc::ClientAsyncReaderInterface<
|
||||
google::bigtable::v2::MutateRowsResponse>>
|
||||
PrepareAsyncMutateRows(grpc::ClientContext* context,
|
||||
const google::bigtable::v2::MutateRowsRequest& request,
|
||||
grpc::CompletionQueue* cq) override;
|
||||
|
||||
virtual std::unique_ptr<::grpc::ClientAsyncReaderInterface<
|
||||
::google::bigtable::v2::ReadRowsResponse>>
|
||||
PrepareAsyncReadRows(::grpc::ClientContext* context,
|
||||
const ::google::bigtable::v2::ReadRowsRequest& request,
|
||||
::grpc::CompletionQueue* cq) override;
|
||||
|
||||
std::shared_ptr<grpc::Channel> Channel() override;
|
||||
|
||||
private:
|
||||
friend class SampleRowKeysResponse;
|
||||
friend class ReadRowsResponse;
|
||||
friend class MutateRowsResponse;
|
||||
|
||||
struct Row {
|
||||
string row_key;
|
||||
std::map<string, string> columns;
|
||||
};
|
||||
struct Table {
|
||||
std::map<string, Row> rows;
|
||||
};
|
||||
|
||||
mutex mu_;
|
||||
const std::string project_id_ = "testproject";
|
||||
const std::string instance_id_ = "testinstance";
|
||||
Table table_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_
|
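For orientation, this is roughly the data model the in-memory test client above maintains, sketched in Python. It is illustrative only: the header only shows that each `Row` holds a flat `std::map<string, string>` of columns, so the `"family:column"` key encoding used here is an assumption.

```python
# Illustrative sketch (not the C++ implementation): a table maps row keys to
# rows, and each row maps a flattened column identifier to its latest value.
test_table = {}  # row_key -> {"family:column": value}


def set_cell(table, row_key, family, column, value):
  """Mirrors what a SetCell mutation does against the in-memory table."""
  table.setdefault(row_key, {})[family + ":" + column] = value


set_cell(test_table, "r1", "f1", "c1", "v1")
assert test_table == {"r1": {"f1:c1": "v1"}}
```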
@ -1,78 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
||||
#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
class BigtableTestClientOp : public OpKernel {
|
||||
public:
|
||||
explicit BigtableTestClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
~BigtableTestClientOp() override {
|
||||
if (cinfo_.resource_is_private_to_kernel()) {
|
||||
if (!cinfo_.resource_manager()
|
||||
->Delete<BigtableClientResource>(cinfo_.container(),
|
||||
cinfo_.name())
|
||||
.ok()) {
|
||||
// Do nothing; the resource may already have been deleted by session resets.
|
||||
}
|
||||
}
|
||||
}
|
||||
void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
|
||||
mutex_lock l(mu_);
|
||||
if (!initialized_) {
|
||||
ResourceMgr* mgr = ctx->resource_manager();
|
||||
OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
|
||||
BigtableClientResource* resource;
|
||||
OP_REQUIRES_OK(
|
||||
ctx,
|
||||
mgr->LookupOrCreate<BigtableClientResource>(
|
||||
cinfo_.container(), cinfo_.name(), &resource,
|
||||
[this, ctx](BigtableClientResource** ret)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
std::shared_ptr<google::cloud::bigtable::DataClient> client(
|
||||
new BigtableTestClient());
|
||||
// Note: must make explicit copies to sequence
|
||||
// them before the move of client.
|
||||
string project_id = client->project_id();
|
||||
string instance_id = client->instance_id();
|
||||
*ret = new BigtableClientResource(std::move(project_id),
|
||||
std::move(instance_id),
|
||||
std::move(client));
|
||||
return Status::OK();
|
||||
}));
|
||||
initialized_ = true;
|
||||
}
|
||||
OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
|
||||
ctx, 0, cinfo_.container(), cinfo_.name(),
|
||||
MakeTypeIndex<BigtableClientResource>()));
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
ContainerInfo cinfo_ GUARDED_BY(mu_);
|
||||
bool initialized_ GUARDED_BY(mu_) = false;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("BigtableTestClient").Device(DEVICE_CPU),
|
||||
BigtableTestClientOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
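From Python, this kernel is reached through the generated op wrappers. The unit test further below wires it up essentially as follows (the test first loads the custom `_bigtable_test.so` op library via `loader.load_op_library`):

```python
from tensorflow.contrib import bigtable
from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
from tensorflow.contrib.bigtable.ops import gen_bigtable_test_ops

# The op creates a BigtableClientResource in the resource manager and returns
# a resource handle, which the BigtableTable op accepts in place of a handle
# to a real Cloud Bigtable client.
client = gen_bigtable_test_ops.bigtable_test_client()
table_resource = gen_bigtable_ops.bigtable_table(client, "testtable")
table = bigtable.BigtableTable("testtable", None, table_resource)
```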
@ -1,349 +0,0 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h"
|
||||
|
||||
#include "google/cloud/bigtable/table.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
void WriteCell(const string& row, const string& family, const string& column,
|
||||
const string& value, ::google::cloud::bigtable::Table* table) {
|
||||
::google::cloud::bigtable::SingleRowMutation mut(row);
|
||||
mut.emplace_back(::google::cloud::bigtable::SetCell(family, column, value));
|
||||
table->Apply(std::move(mut));
|
||||
}
|
||||
|
||||
TEST(BigtableTestClientTest, EmptyRowRead) {
|
||||
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
|
||||
std::make_shared<BigtableTestClient>();
|
||||
::google::cloud::bigtable::Table table(client_ptr, "test_table");
|
||||
|
||||
::google::cloud::bigtable::RowSet rowset;
|
||||
rowset.Append("r1");
|
||||
auto filter = ::google::cloud::bigtable::Filter::Chain(
|
||||
::google::cloud::bigtable::Filter::Latest(1));
|
||||
auto rows = table.ReadRows(std::move(rowset), filter);
|
||||
EXPECT_EQ(rows.begin(), rows.end()) << "Some rows were returned in response!";
|
||||
}
|
||||
|
||||
TEST(BigtableTestClientTest, SingleRowWriteAndRead) {
|
||||
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
|
||||
std::make_shared<BigtableTestClient>();
|
||||
::google::cloud::bigtable::Table table(client_ptr, "test_table");
|
||||
|
||||
WriteCell("r1", "f1", "c1", "v1", &table);
|
||||
|
||||
::google::cloud::bigtable::RowSet rowset("r1");
|
||||
auto filter = ::google::cloud::bigtable::Filter::Chain(
|
||||
::google::cloud::bigtable::Filter::Latest(1));
|
||||
auto rows = table.ReadRows(std::move(rowset), filter);
|
||||
auto itr = rows.begin();
|
||||
EXPECT_NE(itr, rows.end()) << "No rows were returned in response!";
|
||||
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
|
||||
EXPECT_EQ((*itr)->row_key(), "r1");
|
||||
EXPECT_EQ((*itr)->cells().size(), 1);
|
||||
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
|
||||
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
|
||||
EXPECT_EQ((*itr)->cells()[0].value(), "v1");
|
||||
|
||||
++itr;
|
||||
EXPECT_EQ(itr, rows.end());
|
||||
}
|
||||
|
||||
TEST(BigtableTestClientTest, MultiRowWriteAndSingleRowRead) {
|
||||
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
|
||||
std::make_shared<BigtableTestClient>();
|
||||
::google::cloud::bigtable::Table table(client_ptr, "test_table");
|
||||
|
||||
WriteCell("r1", "f1", "c1", "v1", &table);
|
||||
WriteCell("r2", "f1", "c1", "v2", &table);
|
||||
WriteCell("r3", "f1", "c1", "v3", &table);
|
||||
|
||||
::google::cloud::bigtable::RowSet rowset("r1");
|
||||
auto filter = ::google::cloud::bigtable::Filter::Chain(
|
||||
::google::cloud::bigtable::Filter::Latest(1));
|
||||
auto rows = table.ReadRows(std::move(rowset), filter);
|
||||
auto itr = rows.begin();
|
||||
|
||||
EXPECT_NE(itr, rows.end()) << "Missing rows";
|
||||
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
|
||||
EXPECT_EQ((*itr)->row_key(), "r1");
|
||||
EXPECT_EQ((*itr)->cells().size(), 1);
|
||||
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
|
||||
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
|
||||
EXPECT_EQ((*itr)->cells()[0].value(), "v1");
|
||||
|
||||
++itr;
|
||||
EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
|
||||
}
|
||||
|
||||
TEST(BigtableTestClientTest, MultiRowWriteAndRead) {
|
||||
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
|
||||
std::make_shared<BigtableTestClient>();
|
||||
::google::cloud::bigtable::Table table(client_ptr, "test_table");
|
||||
|
||||
WriteCell("r1", "f1", "c1", "v1", &table);
|
||||
WriteCell("r2", "f1", "c1", "v2", &table);
|
||||
WriteCell("r3", "f1", "c1", "v3", &table);
|
||||
|
||||
::google::cloud::bigtable::RowSet rowset("r1", "r2", "r3");
|
||||
auto filter = ::google::cloud::bigtable::Filter::Chain(
|
||||
::google::cloud::bigtable::Filter::Latest(1));
|
||||
auto rows = table.ReadRows(std::move(rowset), filter);
|
||||
auto itr = rows.begin();
|
||||
|
||||
EXPECT_NE(itr, rows.end()) << "Missing rows";
|
||||
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
|
||||
EXPECT_EQ((*itr)->row_key(), "r1");
|
||||
EXPECT_EQ((*itr)->cells().size(), 1);
|
||||
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
|
||||
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
|
||||
EXPECT_EQ((*itr)->cells()[0].value(), "v1");
|
||||
|
||||
++itr;
|
||||
|
||||
EXPECT_NE(itr, rows.end()) << "Missing rows";
|
||||
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
|
||||
EXPECT_EQ((*itr)->row_key(), "r2");
|
||||
EXPECT_EQ((*itr)->cells().size(), 1);
|
||||
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
|
||||
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
|
||||
EXPECT_EQ((*itr)->cells()[0].value(), "v2");
|
||||
|
||||
++itr;
|
||||
|
||||
EXPECT_NE(itr, rows.end()) << "Missing rows";
|
||||
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
|
||||
EXPECT_EQ((*itr)->row_key(), "r3");
|
||||
EXPECT_EQ((*itr)->cells().size(), 1);
|
||||
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
|
||||
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
|
||||
EXPECT_EQ((*itr)->cells()[0].value(), "v3");
|
||||
|
||||
++itr;
|
||||
EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
|
||||
}
|
||||
|
||||
TEST(BigtableTestClientTest, MultiRowWriteAndPrefixRead) {
|
||||
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
|
||||
std::make_shared<BigtableTestClient>();
|
||||
::google::cloud::bigtable::Table table(client_ptr, "test_table");
|
||||
|
||||
WriteCell("r1", "f1", "c1", "v1", &table);
|
||||
WriteCell("r2", "f1", "c1", "v2", &table);
|
||||
WriteCell("r3", "f1", "c1", "v3", &table);
|
||||
|
||||
auto filter = ::google::cloud::bigtable::Filter::Chain(
|
||||
::google::cloud::bigtable::Filter::Latest(1));
|
||||
auto rows =
|
||||
table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter);
|
||||
auto itr = rows.begin();
|
||||
|
||||
EXPECT_NE(itr, rows.end()) << "Missing rows";
|
||||
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
|
||||
EXPECT_EQ((*itr)->row_key(), "r1");
|
||||
EXPECT_EQ((*itr)->cells().size(), 1);
|
||||
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
|
||||
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
|
||||
EXPECT_EQ((*itr)->cells()[0].value(), "v1");
|
||||
|
||||
++itr;
|
||||
|
||||
EXPECT_NE(itr, rows.end()) << "Missing rows";
|
||||
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
|
||||
EXPECT_EQ((*itr)->row_key(), "r2");
|
||||
EXPECT_EQ((*itr)->cells().size(), 1);
|
||||
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
|
||||
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
|
||||
EXPECT_EQ((*itr)->cells()[0].value(), "v2");
|
||||
|
||||
++itr;
|
||||
|
||||
EXPECT_NE(itr, rows.end()) << "Missing rows";
|
||||
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
|
||||
EXPECT_EQ((*itr)->row_key(), "r3");
|
||||
EXPECT_EQ((*itr)->cells().size(), 1);
|
||||
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
|
||||
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
|
||||
EXPECT_EQ((*itr)->cells()[0].value(), "v3");
|
||||
|
||||
++itr;
|
||||
EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
|
||||
}
|
||||
|
||||
TEST(BigtableTestClientTest, ColumnFiltering) {
|
||||
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
|
||||
std::make_shared<BigtableTestClient>();
|
||||
::google::cloud::bigtable::Table table(client_ptr, "test_table");
|
||||
|
||||
WriteCell("r1", "f1", "c1", "v1", &table);
|
||||
WriteCell("r2", "f1", "c1", "v2", &table);
|
||||
WriteCell("r3", "f1", "c1", "v3", &table);
|
||||
|
||||
// Extra cells
|
||||
WriteCell("r1", "f2", "c1", "v1", &table);
|
||||
WriteCell("r2", "f2", "c1", "v2", &table);
|
||||
WriteCell("r3", "f1", "c2", "v3", &table);
|
||||
|
||||
auto filter = ::google::cloud::bigtable::Filter::Chain(
|
||||
::google::cloud::bigtable::Filter::Latest(1),
|
||||
::google::cloud::bigtable::Filter::FamilyRegex("f1"),
|
||||
::google::cloud::bigtable::Filter::ColumnRegex("c1"));
|
||||
auto rows =
|
||||
table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter);
|
||||
auto itr = rows.begin();
|
||||
|
||||
EXPECT_NE(itr, rows.end()) << "Missing rows";
|
||||
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
|
||||
EXPECT_EQ((*itr)->row_key(), "r1");
|
||||
EXPECT_EQ((*itr)->cells().size(), 1);
|
||||
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
|
||||
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
|
||||
EXPECT_EQ((*itr)->cells()[0].value(), "v1");
|
||||
|
||||
++itr;
|
||||
|
||||
EXPECT_NE(itr, rows.end()) << "Missing rows";
|
||||
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
|
||||
EXPECT_EQ((*itr)->row_key(), "r2");
|
||||
EXPECT_EQ((*itr)->cells().size(), 1);
|
||||
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
|
||||
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
|
||||
EXPECT_EQ((*itr)->cells()[0].value(), "v2");
|
||||
|
||||
++itr;
|
||||
|
||||
EXPECT_NE(itr, rows.end()) << "Missing rows";
|
||||
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
|
||||
EXPECT_EQ((*itr)->row_key(), "r3");
|
||||
EXPECT_EQ((*itr)->cells().size(), 1);
|
||||
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
|
||||
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
|
||||
EXPECT_EQ((*itr)->cells()[0].value(), "v3");
|
||||
|
||||
++itr;
|
||||
EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
|
||||
}
|
||||
|
||||
TEST(BigtableTestClientTest, RowKeys) {
|
||||
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
|
||||
std::make_shared<BigtableTestClient>();
|
||||
::google::cloud::bigtable::Table table(client_ptr, "test_table");
|
||||
|
||||
WriteCell("r1", "f1", "c1", "v1", &table);
|
||||
WriteCell("r2", "f1", "c1", "v2", &table);
|
||||
WriteCell("r3", "f1", "c1", "v3", &table);
|
||||
|
||||
// Extra cells
|
||||
WriteCell("r1", "f2", "c1", "v1", &table);
|
||||
WriteCell("r2", "f2", "c1", "v2", &table);
|
||||
WriteCell("r3", "f1", "c2", "v3", &table);
|
||||
|
||||
auto filter = ::google::cloud::bigtable::Filter::Chain(
|
||||
::google::cloud::bigtable::Filter::Latest(1),
|
||||
::google::cloud::bigtable::Filter::CellsRowLimit(1),
|
||||
::google::cloud::bigtable::Filter::StripValueTransformer());
|
||||
auto rows =
|
||||
table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter);
|
||||
auto itr = rows.begin();
|
||||
EXPECT_NE(itr, rows.end()) << "Missing rows";
|
||||
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
|
||||
EXPECT_EQ((*itr)->row_key(), "r1");
|
||||
EXPECT_EQ((*itr)->cells().size(), 1);
|
||||
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
|
||||
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
|
||||
EXPECT_EQ((*itr)->cells()[0].value(), "");
|
||||
|
||||
++itr;
|
||||
|
||||
EXPECT_NE(itr, rows.end()) << "Missing rows";
|
||||
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
|
||||
EXPECT_EQ((*itr)->row_key(), "r2");
|
||||
EXPECT_EQ((*itr)->cells().size(), 1);
|
||||
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
|
||||
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
|
||||
EXPECT_EQ((*itr)->cells()[0].value(), "");
|
||||
|
||||
++itr;
|
||||
|
||||
EXPECT_NE(itr, rows.end()) << "Missing rows";
|
||||
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
|
||||
EXPECT_EQ((*itr)->row_key(), "r3");
|
||||
EXPECT_EQ((*itr)->cells().size(), 1);
|
||||
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
|
||||
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
|
||||
EXPECT_EQ((*itr)->cells()[0].value(), "");
|
||||
|
||||
++itr;
|
||||
EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
|
||||
}
|
||||
|
||||
TEST(BigtableTestClientTest, SampleKeys) {
|
||||
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
|
||||
std::make_shared<BigtableTestClient>();
|
||||
::google::cloud::bigtable::Table table(client_ptr, "test_table");
|
||||
|
||||
WriteCell("r1", "f1", "c1", "v1", &table);
|
||||
WriteCell("r2", "f1", "c1", "v2", &table);
|
||||
WriteCell("r3", "f1", "c1", "v3", &table);
|
||||
WriteCell("r4", "f1", "c1", "v4", &table);
|
||||
WriteCell("r5", "f1", "c1", "v5", &table);
|
||||
|
||||
auto resp = table.SampleRows();
|
||||
EXPECT_TRUE(resp.ok());
|
||||
EXPECT_EQ(3, resp->size());
|
||||
EXPECT_EQ("r1", string((*resp)[0].row_key));
|
||||
EXPECT_EQ(0, (*resp)[0].offset_bytes);
|
||||
EXPECT_EQ("r3", string((*resp)[1].row_key));
|
||||
EXPECT_EQ(100, (*resp)[1].offset_bytes);
|
||||
EXPECT_EQ("r5", string((*resp)[2].row_key));
|
||||
EXPECT_EQ(200, (*resp)[2].offset_bytes);
|
||||
}
|
||||
|
||||
TEST(BigtableTestClientTest, SampleKeysShort) {
|
||||
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
|
||||
std::make_shared<BigtableTestClient>();
|
||||
::google::cloud::bigtable::Table table(client_ptr, "test_table");
|
||||
|
||||
WriteCell("r1", "f1", "c1", "v1", &table);
|
||||
|
||||
auto resp = table.SampleRows();
|
||||
EXPECT_TRUE(resp.ok());
|
||||
EXPECT_EQ(1, resp->size());
|
||||
EXPECT_EQ("r1", string((*resp)[0].row_key));
|
||||
}
|
||||
|
||||
TEST(BigtableTestClientTest, SampleKeysEvenNumber) {
|
||||
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
|
||||
std::make_shared<BigtableTestClient>();
|
||||
::google::cloud::bigtable::Table table(client_ptr, "test_table");
|
||||
|
||||
WriteCell("r1", "f1", "c1", "v1", &table);
|
||||
WriteCell("r2", "f1", "c1", "v2", &table);
|
||||
WriteCell("r3", "f1", "c1", "v3", &table);
|
||||
WriteCell("r4", "f1", "c1", "v4", &table);
|
||||
|
||||
auto resp = table.SampleRows();
|
||||
EXPECT_TRUE(resp.ok());
|
||||
EXPECT_EQ(2, resp->size());
|
||||
EXPECT_EQ("r1", string((*resp)[0].row_key));
|
||||
EXPECT_EQ("r3", string((*resp)[1].row_key));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
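The SampleKeys tests above pin down the fake client's sampling behavior. Here is a small Python sketch of what they assert (every other row key, with offsets spaced 100 bytes apart); it mirrors the test expectations, not the C++ code:

```python
def expected_sample_row_keys(row_keys):
  # Every other row key is reported, with offset_bytes in multiples of 100.
  return [(key, 100 * i) for i, key in enumerate(row_keys[::2])]


assert expected_sample_row_keys(["r1", "r2", "r3", "r4", "r5"]) == [
    ("r1", 0), ("r3", 100), ("r5", 200)]
assert expected_sample_row_keys(["r1"]) == [("r1", 0)]
```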
@ -1,107 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

// TODO(saeta): Add support for setting ClientOptions values.
REGISTER_OP("BigtableClient")
    .Attr("project_id: string")
    .Attr("instance_id: string")
    .Attr("connection_pool_size: int")
    .Attr("max_receive_message_size: int = -1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Output("client: resource")
    .SetShapeFn(shape_inference::ScalarShape);

// TODO(saeta): Add support for Application Profiles.
// See https://cloud.google.com/bigtable/docs/app-profiles for more info.
REGISTER_OP("BigtableTable")
    .Input("client: resource")
    .Attr("table_name: string")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Output("table: resource")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("DatasetToBigtable")
    .Input("table: resource")
    .Input("input_dataset: variant")
    .Input("column_families: string")
    .Input("columns: string")
    .Input("timestamp: int64")
    .SetShapeFn(shape_inference::NoOutputs);

REGISTER_OP("BigtableLookupDataset")
    .Input("keys_dataset: variant")
    .Input("table: resource")
    .Input("column_families: string")
    .Input("columns: string")
    .Output("handle: variant")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("BigtablePrefixKeyDataset")
    .Input("table: resource")
    .Input("prefix: string")
    .Output("handle: variant")
    .SetIsStateful()  // TODO(b/123753214): Source dataset ops must be marked
                      // stateful to inhibit constant folding.
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("BigtableRangeKeyDataset")
    .Input("table: resource")
    .Input("start_key: string")
    .Input("end_key: string")
    .Output("handle: variant")
    .SetIsStateful()  // TODO(b/123753214): Source dataset ops must be marked
                      // stateful to inhibit constant folding.
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("BigtableSampleKeysDataset")
    .Input("table: resource")
    .Output("handle: variant")
    .SetIsStateful()  // TODO(b/123753214): Source dataset ops must be marked
                      // stateful to inhibit constant folding.
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("BigtableSampleKeyPairsDataset")
    .Input("table: resource")
    .Input("prefix: string")
    .Input("start_key: string")
    .Input("end_key: string")
    .Output("handle: variant")
    .SetIsStateful()  // TODO(b/123753214): Source dataset ops must be marked
                      // stateful to inhibit constant folding.
    .SetShapeFn(shape_inference::ScalarShape);

// TODO(saeta): Support continuing despite bad data (e.g. empty string, or
// skip incomplete row.)
REGISTER_OP("BigtableScanDataset")
    .Input("table: resource")
    .Input("prefix: string")
    .Input("start_key: string")
    .Input("end_key: string")
    .Input("column_families: string")
    .Input("columns: string")
    .Input("probability: float")
    .Output("handle: variant")
    .SetIsStateful()  // TODO(b/123753214): Source dataset ops must be marked
                      // stateful to inhibit constant folding.
    .SetShapeFn(shape_inference::ScalarShape);

}  // namespace tensorflow

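Each of the dataset source ops registered above is wrapped by a small `tf.data` dataset class in `bigtable_api.py` (for example `_BigtablePrefixKeyDataset`, referenced further below). The following is a rough sketch of such a wrapper; the generated wrapper name `bigtable_prefix_key_dataset`, the `DatasetSource` base class, and the element-spec API are assumptions that vary across TensorFlow versions.

```python
from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec


class _BigtablePrefixKeyDatasetSketch(dataset_ops.DatasetSource):
  """Yields the row keys in a table that match a given prefix (sketch)."""

  def __init__(self, table, prefix):
    self._table = table
    self._prefix = prefix
    # Assumed wrapper name, following the usual CamelCase -> snake_case
    # convention for REGISTER_OP("BigtablePrefixKeyDataset").
    variant_tensor = gen_bigtable_ops.bigtable_prefix_key_dataset(
        table=table._resource, prefix=prefix)
    super(_BigtablePrefixKeyDatasetSketch, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    # Each element is a scalar string row key.
    return tensor_spec.TensorSpec([], dtypes.string)
```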
@ -1,27 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

REGISTER_OP("BigtableTestClient")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Output("client: resource")
    .SetShapeFn(shape_inference::ScalarShape);

}  // namespace tensorflow

@ -1,20 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""This module contains tests for the bigtable integration."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

@ -1,272 +0,0 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for Bigtable Ops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib import bigtable
|
||||
from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
|
||||
from tensorflow.contrib.bigtable.ops import gen_bigtable_test_ops
|
||||
from tensorflow.contrib.bigtable.python.ops import bigtable_api
|
||||
from tensorflow.contrib.util import loader
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
_bigtable_so = loader.load_op_library(
|
||||
resource_loader.get_path_to_datafile("_bigtable_test.so"))
|
||||
|
||||
|
||||
def _ListOfTuplesOfStringsToBytes(values):
|
||||
return [(compat.as_bytes(i[0]), compat.as_bytes(i[1])) for i in values]
|
||||
|
||||
|
||||
class BigtableOpsTest(test.TestCase):
|
||||
COMMON_ROW_KEYS = ["r1", "r2", "r3"]
|
||||
COMMON_VALUES = ["v1", "v2", "v3"]
|
||||
|
||||
def setUp(self):
|
||||
self._client = gen_bigtable_test_ops.bigtable_test_client()
|
||||
table = gen_bigtable_ops.bigtable_table(self._client, "testtable")
|
||||
self._table = bigtable.BigtableTable("testtable", None, table)
|
||||
|
||||
def _makeSimpleDataset(self):
|
||||
output_rows = dataset_ops.Dataset.from_tensor_slices(self.COMMON_ROW_KEYS)
|
||||
output_values = dataset_ops.Dataset.from_tensor_slices(self.COMMON_VALUES)
|
||||
return dataset_ops.Dataset.zip((output_rows, output_values))
|
||||
|
||||
def _writeCommonValues(self, sess):
|
||||
output_ds = self._makeSimpleDataset()
|
||||
write_op = self._table.write(output_ds, ["cf1"], ["c1"])
|
||||
sess.run(write_op)
|
||||
|
||||
def runReadKeyTest(self, read_ds):
|
||||
itr = dataset_ops.make_initializable_iterator(read_ds)
|
||||
n = itr.get_next()
|
||||
expected = list(self.COMMON_ROW_KEYS)
|
||||
expected.reverse()
|
||||
with self.cached_session() as sess:
|
||||
self._writeCommonValues(sess)
|
||||
sess.run(itr.initializer)
|
||||
for i in range(3):
|
||||
output = sess.run(n)
|
||||
want = expected.pop()
|
||||
self.assertEqual(
|
||||
compat.as_bytes(want), compat.as_bytes(output),
|
||||
"Unequal at step %d: want: %s, got: %s" % (i, want, output))
|
||||
|
||||
def testReadPrefixKeys(self):
|
||||
self.runReadKeyTest(self._table.keys_by_prefix_dataset("r"))
|
||||
|
||||
def testReadRangeKeys(self):
|
||||
self.runReadKeyTest(self._table.keys_by_range_dataset("r1", "r4"))
|
||||
|
||||
def runScanTest(self, read_ds):
|
||||
itr = dataset_ops.make_initializable_iterator(read_ds)
|
||||
n = itr.get_next()
|
||||
expected_keys = list(self.COMMON_ROW_KEYS)
|
||||
expected_keys.reverse()
|
||||
expected_values = list(self.COMMON_VALUES)
|
||||
expected_values.reverse()
|
||||
with self.cached_session() as sess:
|
||||
self._writeCommonValues(sess)
|
||||
sess.run(itr.initializer)
|
||||
for i in range(3):
|
||||
output = sess.run(n)
|
||||
want = expected_keys.pop()
|
||||
self.assertEqual(
|
||||
compat.as_bytes(want), compat.as_bytes(output[0]),
|
||||
"Unequal keys at step %d: want: %s, got: %s" % (i, want, output[0]))
|
||||
want = expected_values.pop()
|
||||
self.assertEqual(
|
||||
compat.as_bytes(want), compat.as_bytes(output[1]),
|
||||
"Unequal values at step: %d: want: %s, got: %s" % (i, want,
|
||||
output[1]))
|
||||
|
||||
def testScanPrefixStringCol(self):
|
||||
self.runScanTest(self._table.scan_prefix("r", cf1="c1"))
|
||||
|
||||
def testScanPrefixListCol(self):
|
||||
self.runScanTest(self._table.scan_prefix("r", cf1=["c1"]))
|
||||
|
||||
def testScanPrefixTupleCol(self):
|
||||
self.runScanTest(self._table.scan_prefix("r", columns=("cf1", "c1")))
|
||||
|
||||
def testScanRangeStringCol(self):
|
||||
self.runScanTest(self._table.scan_range("r1", "r4", cf1="c1"))
|
||||
|
||||
def testScanRangeListCol(self):
|
||||
self.runScanTest(self._table.scan_range("r1", "r4", cf1=["c1"]))
|
||||
|
||||
def testScanRangeTupleCol(self):
|
||||
self.runScanTest(self._table.scan_range("r1", "r4", columns=("cf1", "c1")))
|
||||
|
||||
def testLookup(self):
|
||||
ds = self._table.keys_by_prefix_dataset("r")
|
||||
ds = ds.apply(self._table.lookup_columns(cf1="c1"))
|
||||
itr = dataset_ops.make_initializable_iterator(ds)
|
||||
n = itr.get_next()
|
||||
expected_keys = list(self.COMMON_ROW_KEYS)
|
||||
expected_values = list(self.COMMON_VALUES)
|
||||
expected_tuples = zip(expected_keys, expected_values)
|
||||
with self.cached_session() as sess:
|
||||
self._writeCommonValues(sess)
|
||||
sess.run(itr.initializer)
|
||||
for i, elem in enumerate(expected_tuples):
|
||||
output = sess.run(n)
|
||||
self.assertEqual(
|
||||
compat.as_bytes(elem[0]), compat.as_bytes(output[0]),
|
||||
"Unequal keys at step %d: want: %s, got: %s" %
|
||||
(i, compat.as_bytes(elem[0]), compat.as_bytes(output[0])))
|
||||
self.assertEqual(
|
||||
compat.as_bytes(elem[1]), compat.as_bytes(output[1]),
|
||||
"Unequal values at step %d: want: %s, got: %s" %
|
||||
(i, compat.as_bytes(elem[1]), compat.as_bytes(output[1])))
|
||||
|
||||
def testSampleKeys(self):
|
||||
ds = self._table.sample_keys()
|
||||
itr = dataset_ops.make_initializable_iterator(ds)
|
||||
n = itr.get_next()
|
||||
expected_key = self.COMMON_ROW_KEYS[0]
|
||||
with self.cached_session() as sess:
|
||||
self._writeCommonValues(sess)
|
||||
sess.run(itr.initializer)
|
||||
output = sess.run(n)
|
||||
self.assertEqual(
|
||||
compat.as_bytes(self.COMMON_ROW_KEYS[0]), compat.as_bytes(output),
|
||||
"Unequal keys: want: %s, got: %s" % (compat.as_bytes(
|
||||
self.COMMON_ROW_KEYS[0]), compat.as_bytes(output)))
|
||||
output = sess.run(n)
|
||||
self.assertEqual(
|
||||
compat.as_bytes(self.COMMON_ROW_KEYS[2]), compat.as_bytes(output),
|
||||
"Unequal keys: want: %s, got: %s" % (compat.as_bytes(
|
||||
self.COMMON_ROW_KEYS[2]), compat.as_bytes(output)))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(n)
|
||||
|
||||
def runSampleKeyPairsTest(self, ds, expected_key_pairs):
|
||||
itr = dataset_ops.make_initializable_iterator(ds)
|
||||
n = itr.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self._writeCommonValues(sess)
|
||||
sess.run(itr.initializer)
|
||||
for i, elems in enumerate(expected_key_pairs):
|
||||
output = sess.run(n)
|
||||
self.assertEqual(
|
||||
compat.as_bytes(elems[0]), compat.as_bytes(output[0]),
|
||||
"Unequal key pair (first element) at step %d; want: %s, got %s" %
|
||||
(i, compat.as_bytes(elems[0]), compat.as_bytes(output[0])))
|
||||
self.assertEqual(
|
||||
compat.as_bytes(elems[1]), compat.as_bytes(output[1]),
|
||||
"Unequal key pair (second element) at step %d; want: %s, got %s" %
|
||||
(i, compat.as_bytes(elems[1]), compat.as_bytes(output[1])))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(n)
|
||||
|
||||
def testSampleKeyPairsSimplePrefix(self):
|
||||
ds = bigtable_api._BigtableSampleKeyPairsDataset(
|
||||
self._table, prefix="r", start="", end="")
|
||||
expected_key_pairs = [("r", "r1"), ("r1", "r3"), ("r3", "s")]
|
||||
self.runSampleKeyPairsTest(ds, expected_key_pairs)
|
||||
|
||||
def testSampleKeyPairsSimpleRange(self):
|
||||
ds = bigtable_api._BigtableSampleKeyPairsDataset(
|
||||
self._table, prefix="", start="r1", end="r3")
|
||||
expected_key_pairs = [("r1", "r3")]
|
||||
self.runSampleKeyPairsTest(ds, expected_key_pairs)
|
||||
|
||||
def testSampleKeyPairsSkipRangePrefix(self):
|
||||
ds = bigtable_api._BigtableSampleKeyPairsDataset(
|
||||
self._table, prefix="r2", start="", end="")
|
||||
expected_key_pairs = [("r2", "r3")]
|
||||
self.runSampleKeyPairsTest(ds, expected_key_pairs)
|
||||
|
||||
def testSampleKeyPairsSkipRangeRange(self):
|
||||
ds = bigtable_api._BigtableSampleKeyPairsDataset(
|
||||
self._table, prefix="", start="r2", end="r3")
|
||||
expected_key_pairs = [("r2", "r3")]
|
||||
self.runSampleKeyPairsTest(ds, expected_key_pairs)
|
||||
|
||||
def testSampleKeyPairsOffsetRanges(self):
|
||||
ds = bigtable_api._BigtableSampleKeyPairsDataset(
|
||||
self._table, prefix="", start="r2", end="r4")
|
||||
expected_key_pairs = [("r2", "r3"), ("r3", "r4")]
|
||||
self.runSampleKeyPairsTest(ds, expected_key_pairs)
|
||||
|
||||
def testSampleKeyPairEverything(self):
|
||||
ds = bigtable_api._BigtableSampleKeyPairsDataset(
|
||||
self._table, prefix="", start="", end="")
|
||||
expected_key_pairs = [("", "r1"), ("r1", "r3"), ("r3", "")]
|
||||
self.runSampleKeyPairsTest(ds, expected_key_pairs)
|
||||
|
||||
def testSampleKeyPairsPrefixAndStartKey(self):
|
||||
ds = bigtable_api._BigtableSampleKeyPairsDataset(
|
||||
self._table, prefix="r", start="r1", end="")
|
||||
itr = dataset_ops.make_initializable_iterator(ds)
|
||||
with self.cached_session() as sess:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(itr.initializer)
|
||||
|
||||
def testSampleKeyPairsPrefixAndEndKey(self):
|
||||
ds = bigtable_api._BigtableSampleKeyPairsDataset(
|
||||
self._table, prefix="r", start="", end="r3")
|
||||
itr = dataset_ops.make_initializable_iterator(ds)
|
||||
with self.cached_session() as sess:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
sess.run(itr.initializer)
|
||||
|
||||
def testParallelScanPrefix(self):
|
||||
ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1")
|
||||
itr = dataset_ops.make_initializable_iterator(ds)
|
||||
n = itr.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self._writeCommonValues(sess)
|
||||
sess.run(itr.initializer)
|
||||
expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
|
||||
actual_values = []
|
||||
for _ in range(len(expected_values)):
|
||||
output = sess.run(n)
|
||||
actual_values.append(output)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(n)
|
||||
self.assertItemsEqual(
|
||||
_ListOfTuplesOfStringsToBytes(expected_values),
|
||||
_ListOfTuplesOfStringsToBytes(actual_values))
|
||||
|
||||
def testParallelScanRange(self):
|
||||
ds = self._table.parallel_scan_range(start="r1", end="r4", cf1="c1")
|
||||
itr = dataset_ops.make_initializable_iterator(ds)
|
||||
n = itr.get_next()
|
||||
with self.cached_session() as sess:
|
||||
self._writeCommonValues(sess)
|
||||
sess.run(itr.initializer)
|
||||
expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
|
||||
actual_values = []
|
||||
for _ in range(len(expected_values)):
|
||||
output = sess.run(n)
|
||||
actual_values.append(output)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(n)
|
||||
self.assertItemsEqual(
|
||||
_ListOfTuplesOfStringsToBytes(expected_values),
|
||||
_ListOfTuplesOfStringsToBytes(actual_values))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
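The `runSampleKeyPairsTest` cases above describe how sampled row keys are turned into contiguous `[start, end)` pairs that drive the parallel scans. A Python sketch of that partitioning is below; clamping to a prefix or key range, and the `InvalidArgumentError` cases for combining a prefix with start/end keys, are handled in the C++ kernel and not modeled here.

```python
def sample_key_pairs(sampled_keys, start="", end=""):
  """Turns sampled row keys into consecutive [start, end) scan ranges."""
  boundaries = [start] + list(sampled_keys) + [end]
  return list(zip(boundaries[:-1], boundaries[1:]))


# Matches testSampleKeyPairEverything: the samples are r1 and r3 for three rows.
assert sample_key_pairs(["r1", "r3"]) == [("", "r1"), ("r1", "r3"), ("r3", "")]
```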
@ -1,20 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""This module contains the Python API for the Cloud Bigtable integration."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

@ -1,708 +0,0 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The Python API for TensorFlow's Cloud Bigtable integration.
|
||||
|
||||
TensorFlow has support for reading from and writing to Cloud Bigtable. To use
|
||||
TensorFlow + Cloud Bigtable integration, first create a BigtableClient to
|
||||
configure your connection to Cloud Bigtable, and then create a BigtableTable
|
||||
object to allow you to create numerous `tf.data.Dataset`s to read data, or
|
||||
write a `tf.data.Dataset` object to the underlying Cloud Bigtable table.
|
||||
|
||||
For background on Cloud Bigtable, see: https://cloud.google.com/bigtable .
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from six import iteritems
|
||||
from six import string_types
|
||||
|
||||
from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
|
||||
from tensorflow.contrib.util import loader
|
||||
from tensorflow.python.data.experimental.ops import interleave_ops
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.platform import resource_loader
|
||||
|
||||
_bigtable_so = loader.load_op_library(
|
||||
resource_loader.get_path_to_datafile("_bigtable.so"))
|
||||
|
||||
|
||||
class BigtableClient(object):
|
||||
"""BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF.
|
||||
|
||||
BigtableClient encapsulates a connection to Cloud Bigtable, and exposes the
|
||||
`table` method to open a Bigtable table.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
project_id,
|
||||
instance_id,
|
||||
connection_pool_size=None,
|
||||
max_receive_message_size=None):
|
||||
"""Creates a BigtableClient that can be used to open connections to tables.
|
||||
|
||||
Args:
|
||||
project_id: A string representing the GCP project id to connect to.
|
||||
instance_id: A string representing the Bigtable instance to connect to.
|
||||
connection_pool_size: (Optional.) A number representing the number of
|
||||
concurrent connections to the Cloud Bigtable service to make.
|
||||
max_receive_message_size: (Optional.) The maximum bytes received in a
|
||||
single gRPC response.
|
||||
|
||||
Raises:
|
||||
ValueError: if the arguments are invalid (e.g. wrong type, or out of
|
||||
expected ranges, e.g. negative).
|
||||
"""
|
||||
if not isinstance(project_id, str):
|
||||
raise ValueError("`project_id` must be a string")
|
||||
self._project_id = project_id
|
||||
|
||||
if not isinstance(instance_id, str):
|
||||
raise ValueError("`instance_id` must be a string")
|
||||
self._instance_id = instance_id
|
||||
|
||||
if connection_pool_size is None:
|
||||
connection_pool_size = -1
|
||||
elif connection_pool_size < 1:
|
||||
raise ValueError("`connection_pool_size` must be positive")
|
||||
|
||||
if max_receive_message_size is None:
|
||||
max_receive_message_size = -1
|
||||
elif max_receive_message_size < 1:
|
||||
raise ValueError("`max_receive_message_size` must be positive")
|
||||
|
||||
self._connection_pool_size = connection_pool_size
|
||||
|
||||
self._resource = gen_bigtable_ops.bigtable_client(
|
||||
project_id, instance_id, connection_pool_size, max_receive_message_size)
|
||||
|
||||
def table(self, name, snapshot=None):
|
||||
"""Opens a table and returns a `tf.contrib.bigtable.BigtableTable` object.
|
||||
|
||||
Args:
|
||||
name: A `tf.string` `tf.Tensor` name of the table to open.
|
||||
snapshot: Either a `tf.string` `tf.Tensor` snapshot id, or `True` to
|
||||
request the creation of a snapshot. (Note: currently unimplemented.)
|
||||
|
||||
Returns:
|
||||
A `tf.contrib.bigtable.BigtableTable` Python object representing the
|
||||
operations available on the table.
|
||||
"""
|
||||
# TODO(saeta): Implement snapshot functionality.
|
||||
table = gen_bigtable_ops.bigtable_table(self._resource, name)
|
||||
return BigtableTable(name, snapshot, table)
|
||||
|
||||
|
||||
class BigtableTable(object):
|
||||
"""Entry point for reading and writing data in Cloud Bigtable.
|
||||
|
||||
This BigtableTable class is the Python representation of the Cloud Bigtable
|
||||
table within TensorFlow. Methods on this class allow data to be read from and
|
||||
written to the Cloud Bigtable service in flexible and high performance
|
||||
manners.
|
||||
"""
|
||||
|
||||
# TODO(saeta): Investigate implementing tf.contrib.lookup.LookupInterface.
|
||||
# TODO(saeta): Consider variant tensors instead of resources (while supporting
|
||||
# connection pooling).
|
||||
|
||||
def __init__(self, name, snapshot, resource):
|
||||
self._name = name
|
||||
self._snapshot = snapshot
|
||||
self._resource = resource
|
||||
|
||||
def lookup_columns(self, *args, **kwargs):
|
||||
"""Retrieves the values of columns for a dataset of keys.
|
||||
|
||||
Example usage:
|
||||
|
||||
```python
|
||||
table = bigtable_client.table("my_table")
|
||||
key_dataset = table.get_keys_prefix("imagenet")
|
||||
images = key_dataset.apply(table.lookup_columns(("cf1", "image"),
|
||||
("cf2", "label"),
|
||||
("cf2", "boundingbox")))
|
||||
training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128)
|
||||
```
|
||||
|
||||
Alternatively, you can use keyword arguments to specify the columns to
|
||||
capture. Example (same as above, rewritten):
|
||||
|
||||
```python
|
||||
table = bigtable_client.table("my_table")
|
||||
key_dataset = table.get_keys_prefix("imagenet")
|
||||
images = key_dataset.apply(table.lookup_columns(
|
||||
cf1="image", cf2=("label", "boundingbox")))
|
||||
training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128)
|
||||
```
|
||||
|
||||
Note: certain `kwargs` keys are reserved, and thus, some column families
|
||||
cannot be identified using the `kwargs` syntax. Instead, please use the
|
||||
`args` syntax. This list includes:
|
||||
|
||||
- 'name'
|
||||
|
||||
Note: this list can change at any time.
|
||||
|
||||
Args:
|
||||
*args: A list of tuples containing (column family, column name) pairs.
|
||||
**kwargs: Column families (keys) and column qualifiers (values).
|
||||
|
||||
Returns:
|
||||
A function that can be passed to `tf.data.Dataset.apply` to retrieve the
|
||||
values of columns for the rows.
|
||||
"""
|
||||
table = self # Capture self
|
||||
normalized = args
|
||||
if normalized is None:
|
||||
normalized = []
|
||||
if isinstance(normalized, tuple):
|
||||
normalized = list(normalized)
|
||||
for key, value in iteritems(kwargs):
|
||||
if key == "name":
|
||||
continue
|
||||
if isinstance(value, str):
|
||||
normalized.append((key, value))
|
||||
continue
|
||||
for col in value:
|
||||
normalized.append((key, col))
|
||||
|
||||
def _apply_fn(dataset):
|
||||
# TODO(saeta): Verify dataset's types are correct!
|
||||
return _BigtableLookupDataset(dataset, table, normalized)
|
||||
|
||||
return _apply_fn
|
||||
|
||||
def keys_by_range_dataset(self, start, end):
|
||||
"""Retrieves all row keys between start and end.
|
||||
|
||||
Note: it does NOT retrieve the values of columns.
|
||||
|
||||
Args:
|
||||
start: The start row key. Row keys from `start` (inclusive) onward
|
||||
will be retrieved.
|
||||
end: (Optional.) The end row key. Rows up to (but not including) end will
|
||||
be retrieved. If end is None, all subsequent row keys will be retrieved.
|
||||
|
||||
Returns:
|
||||
A `tf.data.Dataset` containing `tf.string` Tensors corresponding to all
|
||||
of the row keys between `start` and `end`.
|
||||
"""
|
||||
# TODO(saeta): Make inclusive / exclusive configurable?
|
||||
if end is None:
|
||||
end = ""
|
||||
return _BigtableRangeKeyDataset(self, start, end)
|
||||
|
||||
def keys_by_prefix_dataset(self, prefix):
|
||||
"""Retrieves the row keys matching a given prefix.
|
||||
|
||||
Args:
|
||||
prefix: All row keys that begin with `prefix` in the table will be
|
||||
retrieved.
|
||||
|
||||
Returns:
|
||||
A `tf.data.Dataset` containing `tf.string` Tensors corresponding to all
|
||||
of the row keys matching that prefix.
|
||||
"""
|
||||
return _BigtablePrefixKeyDataset(self, prefix)
|
||||
|
||||
def sample_keys(self):
|
||||
"""Retrieves a sampling of row keys from the Bigtable table.
|
||||
|
||||
This dataset is most often used in conjunction with
|
||||
`tf.data.experimental.parallel_interleave` to construct a set of ranges for
|
||||
scanning in parallel.
|
||||
|
||||
Returns:
|
||||
A `tf.data.Dataset` returning string row keys.
|
||||
"""
|
||||
return _BigtableSampleKeysDataset(self)
|
||||
|
||||
def scan_prefix(self, prefix, probability=None, columns=None, **kwargs):
|
||||
"""Retrieves row (including values) from the Bigtable service.
|
||||
|
||||
Rows with row-key prefixed by `prefix` will be retrieved.
|
||||
|
||||
Specifying the columns to retrieve for each row is done by either using
|
||||
kwargs or in the columns parameter. To retrieve values of the columns "c1",
|
||||
and "c2" from the column family "cfa", and the value of the column "c3"
|
||||
from column family "cfb", the following datasets (`ds1`, and `ds2`) are
|
||||
equivalent:
|
||||
|
||||
```
|
||||
table = # ...
|
||||
ds1 = table.scan_prefix("row_prefix", columns=[("cfa", "c1"),
|
||||
("cfa", "c2"),
|
||||
("cfb", "c3")])
|
||||
ds2 = table.scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3")
|
||||
```
|
||||
|
||||
Note: only the latest value of a cell will be retrieved.
|
||||
|
||||
Args:
|
||||
prefix: The prefix all row keys must match to be retrieved for prefix-
|
||||
based scans.
|
||||
probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
|
||||
A non-1 value indicates to probabilistically sample rows with the
|
||||
provided probability.
|
||||
columns: The columns to read. Note: most commonly, they are expressed as
|
||||
kwargs. Use the columns value if you are using column families that are
|
||||
reserved. The value of columns and kwargs are merged. Columns is a list
|
||||
of tuples of strings ("column_family", "column_qualifier").
|
||||
**kwargs: The column families and columns to read. Keys are treated as
|
||||
column_families, and values can be either lists of strings, or strings
|
||||
that are treated as the column qualifier (column name).
|
||||
|
||||
Returns:
|
||||
A `tf.data.Dataset` returning the row keys and the cell contents.
|
||||
|
||||
Raises:
|
||||
ValueError: If the configured probability is unexpected.
|
||||
"""
|
||||
probability = _normalize_probability(probability)
|
||||
normalized = _normalize_columns(columns, kwargs)
|
||||
return _BigtableScanDataset(self, prefix, "", "", normalized, probability)
|
||||
|
||||
def scan_range(self, start, end, probability=None, columns=None, **kwargs):
|
||||
"""Retrieves rows (including values) from the Bigtable service.
|
||||
|
||||
Rows with row-keys between `start` and `end` will be retrieved.
|
||||
|
||||
Specifying the columns to retrieve for each row is done by either using
|
||||
kwargs or in the columns parameter. To retrieve values of the columns "c1",
|
||||
and "c2" from the column family "cfa", and the value of the column "c3"
|
||||
from column family "cfb", the following datasets (`ds1`, and `ds2`) are
|
||||
equivalent:
|
||||
|
||||
```
|
||||
table = # ...
|
||||
ds1 = table.scan_range("row_start", "row_end", columns=[("cfa", "c1"),
|
||||
("cfa", "c2"),
|
||||
("cfb", "c3")])
|
||||
ds2 = table.scan_range("row_start", "row_end", cfa=["c1", "c2"], cfb="c3")
|
||||
```
|
||||
|
||||
Note: only the latest value of a cell will be retrieved.
|
||||
|
||||
Args:
|
||||
start: The start of the range when scanning by range.
|
||||
end: (Optional.) The end of the range when scanning by range.
|
||||
probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
|
||||
A non-1 value indicates to probabilistically sample rows with the
|
||||
provided probability.
|
||||
columns: The columns to read. Note: most commonly, they are expressed as
|
||||
kwargs. Use the columns value if you are using column families that are
|
||||
reserved. The value of columns and kwargs are merged. Columns is a list
|
||||
of tuples of strings ("column_family", "column_qualifier").
|
||||
**kwargs: The column families and columns to read. Keys are treated as
|
||||
column_families, and values can be either lists of strings, or strings
|
||||
that are treated as the column qualifier (column name).
|
||||
|
||||
Returns:
|
||||
A `tf.data.Dataset` returning the row keys and the cell contents.
|
||||
|
||||
Raises:
|
||||
ValueError: If the configured probability is unexpected.
|
||||
"""
|
||||
probability = _normalize_probability(probability)
|
||||
normalized = _normalize_columns(columns, kwargs)
|
||||
return _BigtableScanDataset(self, "", start, end, normalized, probability)
|
||||
|
||||
def parallel_scan_prefix(self,
|
||||
prefix,
|
||||
num_parallel_scans=None,
|
||||
probability=None,
|
||||
columns=None,
|
||||
**kwargs):
|
||||
"""Retrieves row (including values) from the Bigtable service at high speed.
|
||||
|
||||
Rows with row-key prefixed by `prefix` will be retrieved. This method is
|
||||
similar to `scan_prefix`, but by contrast performs multiple sub-scans in
|
||||
parallel in order to achieve higher performance.
|
||||
|
||||
Note: The dataset produced by this method is not deterministic!
|
||||
|
||||
Specifying the columns to retrieve for each row is done by either using
|
||||
kwargs or in the columns parameter. To retrieve values of the columns "c1",
|
||||
and "c2" from the column family "cfa", and the value of the column "c3"
|
||||
from column family "cfb", the following datasets (`ds1`, and `ds2`) are
|
||||
equivalent:
|
||||
|
||||
```
|
||||
table = # ...
|
||||
ds1 = table.parallel_scan_prefix("row_prefix", columns=[("cfa", "c1"),
|
||||
("cfa", "c2"),
|
||||
("cfb", "c3")])
|
||||
ds2 = table.parallel_scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3")
|
||||
```
|
||||
|
||||
Note: only the latest value of a cell will be retrieved.
|
||||
|
||||
Args:
|
||||
prefix: The prefix all row keys must match to be retrieved for prefix-
|
||||
based scans.
|
||||
num_parallel_scans: (Optional.) The number of concurrent scans against the
|
||||
Cloud Bigtable instance.
|
||||
probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
|
||||
A non-1 value indicates to probabilistically sample rows with the
|
||||
provided probability.
|
||||
columns: The columns to read. Note: most commonly, they are expressed as
|
||||
kwargs. Use the columns value if you are using column families that are
|
||||
reserved. The value of columns and kwargs are merged. Columns is a list
|
||||
of tuples of strings ("column_family", "column_qualifier").
|
||||
**kwargs: The column families and columns to read. Keys are treated as
|
||||
column_families, and values can be either lists of strings, or strings
|
||||
that are treated as the column qualifier (column name).
|
||||
|
||||
Returns:
|
||||
A `tf.data.Dataset` returning the row keys and the cell contents.
|
||||
|
||||
Raises:
|
||||
ValueError: If the configured probability is unexpected.
|
||||
"""
|
||||
probability = _normalize_probability(probability)
|
||||
normalized = _normalize_columns(columns, kwargs)
|
||||
ds = _BigtableSampleKeyPairsDataset(self, prefix, "", "")
|
||||
return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability,
|
||||
normalized)
|
||||
|
||||
def parallel_scan_range(self,
|
||||
start,
|
||||
end,
|
||||
num_parallel_scans=None,
|
||||
probability=None,
|
||||
columns=None,
|
||||
**kwargs):
|
||||
"""Retrieves rows (including values) from the Bigtable service.
|
||||
|
||||
Rows with row-keys between `start` and `end` will be retrieved. This method
|
||||
is similar to `scan_range`, but by contrast performs multiple sub-scans in
|
||||
parallel in order to achieve higher performance.
|
||||
|
||||
Note: The dataset produced by this method is not deterministic!
|
||||
|
||||
    Specifying the columns to retrieve for each row is done either by using
    kwargs or the columns parameter. To retrieve values of the columns "c1",
|
||||
and "c2" from the column family "cfa", and the value of the column "c3"
|
||||
from column family "cfb", the following datasets (`ds1`, and `ds2`) are
|
||||
equivalent:
|
||||
|
||||
```
|
||||
table = # ...
|
||||
ds1 = table.parallel_scan_range("row_start",
|
||||
"row_end",
|
||||
columns=[("cfa", "c1"),
|
||||
("cfa", "c2"),
|
||||
("cfb", "c3")])
|
||||
ds2 = table.parallel_scan_range("row_start", "row_end",
|
||||
cfa=["c1", "c2"], cfb="c3")
|
||||
```
|
||||
|
||||
Note: only the latest value of a cell will be retrieved.
|
||||
|
||||
Args:
|
||||
start: The start of the range when scanning by range.
|
||||
end: (Optional.) The end of the range when scanning by range.
|
||||
num_parallel_scans: (Optional.) The number of concurrent scans against the
|
||||
Cloud Bigtable instance.
|
||||
      probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
        A value below 1 causes rows to be sampled probabilistically with the
        provided probability.
      columns: The columns to read. Note: most commonly, they are expressed as
        kwargs. Use the columns value if you are using column family names that
        are reserved. The values of columns and kwargs are merged. Columns is a
        list of tuples of strings ("column_family", "column_qualifier").
|
||||
**kwargs: The column families and columns to read. Keys are treated as
|
||||
column_families, and values can be either lists of strings, or strings
|
||||
that are treated as the column qualifier (column name).
|
||||
|
||||
Returns:
|
||||
A `tf.data.Dataset` returning the row keys and the cell contents.
|
||||
|
||||
Raises:
|
||||
ValueError: If the configured probability is unexpected.
|
||||
"""
|
||||
probability = _normalize_probability(probability)
|
||||
normalized = _normalize_columns(columns, kwargs)
|
||||
ds = _BigtableSampleKeyPairsDataset(self, "", start, end)
|
||||
return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability,
|
||||
normalized)
|
||||
|
||||
def write(self, dataset, column_families, columns, timestamp=None):
|
||||
"""Writes a dataset to the table.
|
||||
|
||||
Args:
|
||||
dataset: A `tf.data.Dataset` to be written to this table. It must produce
|
||||
a list of number-of-columns+1 elements, all of which must be strings.
|
||||
The first value will be used as the row key, and subsequent values will
|
||||
be used as cell values for the corresponding columns from the
|
||||
corresponding column_families and columns entries.
|
||||
      column_families: A `tf.Tensor` of `tf.string`s corresponding to the
        column family names to store the dataset's elements into.
|
||||
columns: A `tf.Tensor` of `tf.string`s corresponding to the column names
|
||||
to store the dataset's elements into.
|
||||
timestamp: (Optional.) An int64 timestamp to write all the values at.
|
||||
Leave as None to use server-provided timestamps.
|
||||
|
||||
Returns:
|
||||
A `tf.Operation` that can be run to perform the write.
|
||||
|
||||
Raises:
|
||||
ValueError: If there are unexpected or incompatible types, or if the
|
||||
number of columns and column_families does not match the output of
|
||||
`dataset`.
|
||||
"""
|
||||
if timestamp is None:
|
||||
timestamp = -1 # Bigtable server provided timestamp.
|
||||
for tensor_type in nest.flatten(
|
||||
dataset_ops.get_legacy_output_types(dataset)):
|
||||
if tensor_type != dtypes.string:
|
||||
raise ValueError("Not all elements of the dataset were `tf.string`")
|
||||
for shape in nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)):
|
||||
if not shape.is_compatible_with(tensor_shape.TensorShape([])):
|
||||
raise ValueError("Not all elements of the dataset were scalars")
|
||||
if len(column_families) != len(columns):
|
||||
raise ValueError("len(column_families) != len(columns)")
|
||||
if len(nest.flatten(
|
||||
dataset_ops.get_legacy_output_types(dataset))) != len(columns) + 1:
|
||||
raise ValueError("A column name must be specified for every component of "
|
||||
"the dataset elements. (e.g.: len(columns) != "
|
||||
"len(dataset.output_types))")
|
||||
return gen_bigtable_ops.dataset_to_bigtable(
|
||||
self._resource,
|
||||
dataset._variant_tensor, # pylint: disable=protected-access
|
||||
column_families,
|
||||
columns,
|
||||
timestamp)
|
||||
|
||||
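A minimal sketch of the dataset shape that `write` expects, assuming `table`
is an instance of this table class and the column families/qualifiers are
hypothetical; each dataset element supplies the row key followed by one value
per configured column:

```
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(
    (["row1", "row2"],    # row keys
     ["a1", "a2"],        # values written to cfa:c1
     ["b1", "b2"]))       # values written to cfb:c2
write_op = table.write(ds,
                       column_families=["cfa", "cfb"],
                       columns=["c1", "c2"])
# In graph mode, run `write_op` in a session to perform the write.
```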
def _make_parallel_scan_dataset(self, ds, num_parallel_scans,
|
||||
normalized_probability, normalized_columns):
|
||||
"""Builds a parallel dataset from a given range.
|
||||
|
||||
Args:
|
||||
ds: A `_BigtableSampleKeyPairsDataset` returning ranges of keys to use.
|
||||
num_parallel_scans: The number of concurrent parallel scans to use.
|
||||
normalized_probability: A number between 0 and 1 for the keep probability.
|
||||
normalized_columns: The column families and column qualifiers to retrieve.
|
||||
|
||||
Returns:
|
||||
A `tf.data.Dataset` representing the result of the parallel scan.
|
||||
"""
|
||||
if num_parallel_scans is None:
|
||||
num_parallel_scans = 50
|
||||
|
||||
ds = ds.shuffle(buffer_size=10000) # TODO(saeta): Make configurable.
|
||||
|
||||
def _interleave_fn(start, end):
|
||||
return _BigtableScanDataset(
|
||||
self,
|
||||
prefix="",
|
||||
start=start,
|
||||
end=end,
|
||||
normalized=normalized_columns,
|
||||
probability=normalized_probability)
|
||||
|
||||
# Note prefetch_input_elements must be set in order to avoid rpc timeouts.
|
||||
ds = ds.apply(
|
||||
interleave_ops.parallel_interleave(
|
||||
_interleave_fn,
|
||||
cycle_length=num_parallel_scans,
|
||||
sloppy=True,
|
||||
prefetch_input_elements=1))
|
||||
return ds
|
||||
|
||||
|
||||
def _normalize_probability(probability):
  if probability is None:
    probability = 1.0
  if isinstance(probability, float) and (probability <= 0.0 or
                                         probability > 1.0):
    raise ValueError("probability must be in the range (0, 1].")
  return probability


def _normalize_columns(columns, provided_kwargs):
  """Converts arguments (columns, and kwargs dict) to C++ representation.

  Args:
    columns: a data structure containing the column families and qualifiers to
      retrieve. Valid types include (1) None, (2) a list of tuples, and (3) a
      tuple of strings.
    provided_kwargs: a dictionary containing the column families and qualifiers
      to retrieve.

  Returns:
    A list of pairs of column family+qualifier to retrieve.

  Raises:
    ValueError: If there are no cells to retrieve or the columns are in an
      incorrect format.
  """
  normalized = columns
  if normalized is None:
    normalized = []
  if isinstance(normalized, tuple):
    if len(normalized) == 2:
      normalized = [normalized]
    else:
      raise ValueError("columns was a tuple of inappropriate length")
  for key, value in iteritems(provided_kwargs):
    if key == "name":
      continue
    if isinstance(value, string_types):
      normalized.append((key, value))
      continue
    for col in value:
      normalized.append((key, col))
  if not normalized:
    raise ValueError("At least one column + column family must be specified.")
  return normalized
|
||||
|
||||
|
||||
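An illustrative sketch of the normalization convention implemented above; the
column family and qualifier names are hypothetical:

```
pairs = _normalize_columns(columns=[("cfa", "c1")],
                           provided_kwargs={"cfb": ["c2", "c3"], "name": "x"})
# pairs == [("cfa", "c1"), ("cfb", "c2"), ("cfb", "c3")]
# The reserved "name" kwarg is skipped; a plain string value would be treated
# as a single qualifier for its column family.
```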
class _BigtableKeyDataset(dataset_ops.DatasetSource):
|
||||
"""_BigtableKeyDataset is an abstract class representing the keys of a table.
|
||||
"""
|
||||
|
||||
def __init__(self, table, variant_tensor):
|
||||
"""Constructs a _BigtableKeyDataset.
|
||||
|
||||
Args:
|
||||
table: a Bigtable class.
|
||||
variant_tensor: DT_VARIANT representation of the dataset.
|
||||
"""
|
||||
super(_BigtableKeyDataset, self).__init__(variant_tensor)
|
||||
self._table = table
|
||||
|
||||
@property
|
||||
def element_spec(self):
|
||||
return tensor_spec.TensorSpec([], dtypes.string)
|
||||
|
||||
|
||||
class _BigtablePrefixKeyDataset(_BigtableKeyDataset):
|
||||
"""_BigtablePrefixKeyDataset represents looking up keys by prefix.
|
||||
"""
|
||||
|
||||
def __init__(self, table, prefix):
|
||||
self._prefix = prefix
|
||||
variant_tensor = gen_bigtable_ops.bigtable_prefix_key_dataset(
|
||||
table=table._resource, # pylint: disable=protected-access
|
||||
prefix=self._prefix)
|
||||
super(_BigtablePrefixKeyDataset, self).__init__(table, variant_tensor)
|
||||
|
||||
|
||||
class _BigtableRangeKeyDataset(_BigtableKeyDataset):
|
||||
"""_BigtableRangeKeyDataset represents looking up keys by range.
|
||||
"""
|
||||
|
||||
def __init__(self, table, start, end):
|
||||
self._start = start
|
||||
self._end = end
|
||||
variant_tensor = gen_bigtable_ops.bigtable_range_key_dataset(
|
||||
table=table._resource, # pylint: disable=protected-access
|
||||
start_key=self._start,
|
||||
end_key=self._end)
|
||||
super(_BigtableRangeKeyDataset, self).__init__(table, variant_tensor)
|
||||
|
||||
|
||||
class _BigtableSampleKeysDataset(_BigtableKeyDataset):
|
||||
"""_BigtableSampleKeysDataset represents a sampling of row keys.
|
||||
"""
|
||||
|
||||
# TODO(saeta): Expose the data size offsets into the keys.
|
||||
|
||||
def __init__(self, table):
|
||||
variant_tensor = gen_bigtable_ops.bigtable_sample_keys_dataset(
|
||||
table=table._resource) # pylint: disable=protected-access
|
||||
super(_BigtableSampleKeysDataset, self).__init__(table, variant_tensor)
|
||||
|
||||
|
||||
class _BigtableLookupDataset(dataset_ops.DatasetSource):
|
||||
"""_BigtableLookupDataset represents a dataset that retrieves values for keys.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, table, normalized):
|
||||
self._num_outputs = len(normalized) + 1 # 1 for row key
|
||||
self._dataset = dataset
|
||||
self._table = table
|
||||
self._normalized = normalized
|
||||
self._column_families = [i[0] for i in normalized]
|
||||
self._columns = [i[1] for i in normalized]
|
||||
variant_tensor = gen_bigtable_ops.bigtable_lookup_dataset(
|
||||
keys_dataset=self._dataset._variant_tensor, # pylint: disable=protected-access
|
||||
table=self._table._resource, # pylint: disable=protected-access
|
||||
column_families=self._column_families,
|
||||
columns=self._columns)
|
||||
super(_BigtableLookupDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def element_spec(self):
|
||||
return tuple([tensor_spec.TensorSpec([], dtypes.string)] *
|
||||
self._num_outputs)
|
||||
|
||||
|
||||
class _BigtableScanDataset(dataset_ops.DatasetSource):
|
||||
"""_BigtableScanDataset represents a dataset that retrieves keys and values.
|
||||
"""
|
||||
|
||||
def __init__(self, table, prefix, start, end, normalized, probability):
|
||||
self._table = table
|
||||
self._prefix = prefix
|
||||
self._start = start
|
||||
self._end = end
|
||||
self._column_families = [i[0] for i in normalized]
|
||||
self._columns = [i[1] for i in normalized]
|
||||
self._probability = probability
|
||||
self._num_outputs = len(normalized) + 1 # 1 for row key
|
||||
variant_tensor = gen_bigtable_ops.bigtable_scan_dataset(
|
||||
table=self._table._resource, # pylint: disable=protected-access
|
||||
prefix=self._prefix,
|
||||
start_key=self._start,
|
||||
end_key=self._end,
|
||||
column_families=self._column_families,
|
||||
columns=self._columns,
|
||||
probability=self._probability)
|
||||
super(_BigtableScanDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def element_spec(self):
|
||||
return tuple([tensor_spec.TensorSpec([], dtypes.string)] *
|
||||
self._num_outputs)
|
||||
|
||||
|
||||
class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource):
|
||||
"""_BigtableSampleKeyPairsDataset returns key pairs from a Bigtable table.
|
||||
"""
|
||||
|
||||
def __init__(self, table, prefix, start, end):
|
||||
self._table = table
|
||||
self._prefix = prefix
|
||||
self._start = start
|
||||
self._end = end
|
||||
variant_tensor = gen_bigtable_ops.bigtable_sample_key_pairs_dataset(
|
||||
table=self._table._resource, # pylint: disable=protected-access
|
||||
prefix=self._prefix,
|
||||
start_key=self._start,
|
||||
end_key=self._end)
|
||||
super(_BigtableSampleKeyPairsDataset, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
def element_spec(self):
|
||||
return (tensor_spec.TensorSpec([], dtypes.string),
|
||||
tensor_spec.TensorSpec([], dtypes.string))
|
@ -1,614 +0,0 @@
|
||||
# TensorFlow code for training gradient boosted trees.
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test", "tf_custom_op_library", "tf_gen_op_libs", "tf_gen_op_wrapper_py", "tf_kernel_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
package_group(name = "friends")
|
||||
|
||||
cc_library(
|
||||
name = "boosted_trees_kernels",
|
||||
deps = [
|
||||
":model_ops_kernels",
|
||||
":prediction_ops_kernels",
|
||||
":quantile_ops_kernels",
|
||||
":split_handler_ops_kernels",
|
||||
":stats_accumulator_ops_kernels",
|
||||
":training_ops_kernels",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "boosted_trees_ops_op_lib",
|
||||
deps = [
|
||||
":model_ops_op_lib",
|
||||
":prediction_ops_op_lib",
|
||||
":quantile_ops_op_lib",
|
||||
":split_handler_ops_op_lib",
|
||||
":stats_accumulator_ops_op_lib",
|
||||
":training_ops_op_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "init_py",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"python/__init__.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":boosted_trees_ops_py",
|
||||
":losses",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "losses",
|
||||
srcs = ["python/utils/losses.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nn",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "losses_test",
|
||||
size = "small",
|
||||
srcs = ["python/utils/losses_test.py"],
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":losses",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gbdt_batch",
|
||||
srcs = [
|
||||
"python/training/functions/gbdt_batch.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":gen_model_ops_py",
|
||||
"//tensorflow/contrib/boosted_trees:batch_ops_utils_py",
|
||||
"//tensorflow/contrib/boosted_trees:boosted_trees_ops_py",
|
||||
"//tensorflow/contrib/boosted_trees/lib:categorical_split_handler",
|
||||
"//tensorflow/contrib/boosted_trees/lib:ordinal_split_handler",
|
||||
"//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:stateless_random_ops",
|
||||
"//tensorflow/python:summary",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/feature_column",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "gbdt_batch_test",
|
||||
size = "medium",
|
||||
srcs = ["python/training/functions/gbdt_batch_test.py"],
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"nofwdcompat", # b/137641346
|
||||
"notsan", # b/62863147
|
||||
],
|
||||
deps = [
|
||||
":gbdt_batch",
|
||||
":losses",
|
||||
":model_ops_py",
|
||||
"//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
|
||||
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:resources",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
)
|
||||
|
||||
# Kernel tests
|
||||
|
||||
py_test(
|
||||
name = "model_ops_test",
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/model_ops_test.py"],
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":model_ops_py",
|
||||
":prediction_ops_py",
|
||||
"//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
|
||||
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:resources",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variables",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "prediction_ops_test",
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/prediction_ops_test.py"],
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":model_ops_py",
|
||||
":prediction_ops_py",
|
||||
"//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
|
||||
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:resources",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "quantile_ops_test",
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/quantile_ops_test.py"],
|
||||
python_version = "PY2",
|
||||
shard_count = 3,
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":quantile_ops_py",
|
||||
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:resources",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:training",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "split_handler_ops_test",
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/split_handler_ops_test.py"],
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":split_handler_ops_py",
|
||||
"//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
|
||||
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "stats_accumulator_ops_test",
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/stats_accumulator_ops_test.py"],
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":stats_accumulator_ops_py",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "training_ops_test",
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/training_ops_test.py"],
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":model_ops_py",
|
||||
":training_ops_py",
|
||||
"//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
|
||||
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_py",
|
||||
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:resources",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
# Ops
|
||||
|
||||
py_library(
|
||||
name = "batch_ops_utils_py",
|
||||
srcs = ["python/ops/batch_ops_utils.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
],
|
||||
)
|
||||
|
||||
tf_custom_op_py_library(
|
||||
name = "boosted_trees_ops_loader",
|
||||
srcs = ["python/ops/boosted_trees_ops_loader.py"],
|
||||
dso = [":python/ops/_boosted_trees_ops.so"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/util:util_py",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:platform",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "boosted_trees_ops_py",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":model_ops_py",
|
||||
":prediction_ops_py",
|
||||
":quantile_ops_py",
|
||||
":split_handler_ops_py",
|
||||
":stats_accumulator_ops_py",
|
||||
":training_ops_py",
|
||||
],
|
||||
)
|
||||
|
||||
# Model Ops.
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = ["model_ops"],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "gen_model_ops_py",
|
||||
out = "python/ops/gen_model_ops.py",
|
||||
deps = [":model_ops_op_lib"],
|
||||
)
|
||||
|
||||
tf_custom_op_py_library(
|
||||
name = "model_ops_py",
|
||||
srcs = ["python/ops/model_ops.py"],
|
||||
kernels = [
|
||||
":model_ops_kernels",
|
||||
":model_ops_op_lib",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":boosted_trees_ops_loader",
|
||||
":gen_model_ops_py",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:resources",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "model_ops_kernels",
|
||||
srcs = ["kernels/model_ops.cc"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/boosted_trees/lib:utils",
|
||||
"//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_custom_op_library(
|
||||
name = "python/ops/_boosted_trees_ops.so",
|
||||
srcs = [
|
||||
"kernels/model_ops.cc",
|
||||
"kernels/prediction_ops.cc",
|
||||
"kernels/quantile_ops.cc",
|
||||
"kernels/split_handler_ops.cc",
|
||||
"kernels/stats_accumulator_ops.cc",
|
||||
"kernels/training_ops.cc",
|
||||
"ops/model_ops.cc",
|
||||
"ops/prediction_ops.cc",
|
||||
"ops/quantile_ops.cc",
|
||||
"ops/split_handler_ops.cc",
|
||||
"ops/stats_accumulator_ops.cc",
|
||||
"ops/training_ops.cc",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/contrib/boosted_trees/lib:example_partitioner",
|
||||
"//tensorflow/contrib/boosted_trees/lib:models",
|
||||
"//tensorflow/contrib/boosted_trees/lib:node-stats",
|
||||
"//tensorflow/contrib/boosted_trees/lib:utils",
|
||||
"//tensorflow/contrib/boosted_trees/lib:weighted_quantiles",
|
||||
"//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
|
||||
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc",
|
||||
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc",
|
||||
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
|
||||
"//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
|
||||
"//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource",
|
||||
"//tensorflow/contrib/boosted_trees/resources:stamped_resource",
|
||||
],
|
||||
)
|
||||
|
||||
# Split handler Ops.
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = ["split_handler_ops"],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "gen_split_handler_ops_py",
|
||||
out = "python/ops/gen_split_handler_ops.py",
|
||||
deps = [
|
||||
":split_handler_ops_op_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_custom_op_py_library(
|
||||
name = "split_handler_ops_py",
|
||||
srcs = ["python/ops/split_handler_ops.py"],
|
||||
kernels = [
|
||||
":split_handler_ops_kernels",
|
||||
":split_handler_ops_op_lib",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":boosted_trees_ops_loader",
|
||||
":gen_split_handler_ops_py",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "split_handler_ops_kernels",
|
||||
srcs = ["kernels/split_handler_ops.cc"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/boosted_trees/lib:node-stats",
|
||||
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc",
|
||||
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Training Ops.
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = [
|
||||
"training_ops",
|
||||
],
|
||||
deps = ["//tensorflow/contrib/boosted_trees/proto:learner_proto_cc"],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "gen_training_ops_py",
|
||||
out = "python/ops/gen_training_ops.py",
|
||||
deps = [
|
||||
":training_ops_op_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_custom_op_py_library(
|
||||
name = "training_ops_py",
|
||||
srcs = ["python/ops/training_ops.py"],
|
||||
kernels = [
|
||||
":training_ops_kernels",
|
||||
":training_ops_op_lib",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":boosted_trees_ops_loader",
|
||||
":gen_training_ops_py",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "training_ops_kernels",
|
||||
srcs = ["kernels/training_ops.cc"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/boosted_trees/lib:utils",
|
||||
"//tensorflow/contrib/boosted_trees/lib:weighted_quantiles",
|
||||
"//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
|
||||
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc",
|
||||
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc",
|
||||
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
|
||||
"//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
|
||||
"//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Prediction Ops.
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = ["prediction_ops"],
|
||||
deps = ["//tensorflow/contrib/boosted_trees/proto:learner_proto_cc"],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "gen_prediction_ops_py",
|
||||
out = "python/ops/gen_prediction_ops.py",
|
||||
deps = [
|
||||
":prediction_ops_op_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_custom_op_py_library(
|
||||
name = "prediction_ops_py",
|
||||
srcs = ["python/ops/prediction_ops.py"],
|
||||
kernels = [
|
||||
":prediction_ops_kernels",
|
||||
":prediction_ops_op_lib",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":boosted_trees_ops_loader",
|
||||
":gen_prediction_ops_py",
|
||||
"//tensorflow/contrib/util:util_py",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "prediction_ops_kernels",
|
||||
srcs = ["kernels/prediction_ops.cc"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/boosted_trees/lib:example_partitioner",
|
||||
"//tensorflow/contrib/boosted_trees/lib:models",
|
||||
"//tensorflow/contrib/boosted_trees/lib:utils",
|
||||
"//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
|
||||
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
|
||||
"//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Quantile ops
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = ["quantile_ops"],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "gen_quantile_ops_py_wrap",
|
||||
out = "python/ops/gen_quantile_ops.py",
|
||||
deps = [
|
||||
":quantile_ops_op_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_custom_op_py_library(
|
||||
name = "quantile_ops_py",
|
||||
srcs = ["python/ops/quantile_ops.py"],
|
||||
kernels = [
|
||||
":quantile_ops_kernels",
|
||||
":quantile_ops_op_lib",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":batch_ops_utils_py",
|
||||
":boosted_trees_ops_loader",
|
||||
":gen_quantile_ops_py_wrap",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:resources",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "quantile_ops_kernels",
|
||||
srcs = ["kernels/quantile_ops.cc"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/boosted_trees/lib:utils",
|
||||
"//tensorflow/contrib/boosted_trees/lib:weighted_quantiles",
|
||||
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc",
|
||||
"//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Stats Accumulator ops
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = ["stats_accumulator_ops"],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "gen_stats_accumulator_ops_py_wrap",
|
||||
out = "python/ops/gen_stats_accumulator_ops.py",
|
||||
deps = [
|
||||
":stats_accumulator_ops_op_lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_custom_op_py_library(
|
||||
name = "stats_accumulator_ops_py",
|
||||
srcs = ["python/ops/stats_accumulator_ops.py"],
|
||||
kernels = [
|
||||
":stats_accumulator_ops_kernels",
|
||||
":stats_accumulator_ops_op_lib",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":batch_ops_utils_py",
|
||||
":boosted_trees_ops_loader",
|
||||
":gen_stats_accumulator_ops_py_wrap",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:resources",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "stats_accumulator_ops_kernels",
|
||||
srcs = ["kernels/stats_accumulator_ops.cc"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/boosted_trees/lib:utils",
|
||||
"//tensorflow/contrib/boosted_trees/resources:stamped_resource",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Pip
|
||||
|
||||
py_library(
|
||||
name = "boosted_trees_pip",
|
||||
deps = [
|
||||
":init_py",
|
||||
"//tensorflow/contrib/boosted_trees:gbdt_batch",
|
||||
"//tensorflow/contrib/boosted_trees/estimator_batch:custom_export_strategy",
|
||||
"//tensorflow/contrib/boosted_trees/estimator_batch:dnn_tree_combined_estimator",
|
||||
"//tensorflow/contrib/boosted_trees/estimator_batch:init_py",
|
||||
"//tensorflow/contrib/boosted_trees/estimator_batch:trainer_hooks",
|
||||
"//tensorflow/contrib/boosted_trees/lib:categorical_split_handler",
|
||||
"//tensorflow/contrib/boosted_trees/lib:ordinal_split_handler",
|
||||
"//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
|
||||
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_py",
|
||||
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_py",
|
||||
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py",
|
||||
],
|
||||
)
|
@ -1,11 +0,0 @@
|
||||
# TF Boosted Trees (TFBT)

TF Boosted Trees is an implementation of a gradient boosting algorithm with
trees used as weak learners.

## Examples
The "examples" folder demonstrates how TFBT estimators can be used for various
problems. Namely, it contains:
* binary_mnist.py - an example of how to use TFBT for binary classification.
* mnist.py - a multiclass example.
* boston.py - a regression example.
|
@ -1,22 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Gradient boosted trees implementation in tensorflow."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
from tensorflow.contrib.boosted_trees.python import *
|
||||
# pylint: enable=unused-import,wildcard-import
|
@ -1,220 +0,0 @@
|
||||
# This directory contains estimators to train and run inference on
|
||||
# gradient boosted trees on top of TensorFlow.
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_library(
|
||||
name = "init_py",
|
||||
srcs = ["__init__.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":custom_export_strategy",
|
||||
":custom_loss_head",
|
||||
":distillation_loss",
|
||||
":estimator",
|
||||
":model",
|
||||
":trainer_hooks",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model",
|
||||
srcs = ["model.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":estimator_utils",
|
||||
":trainer_hooks",
|
||||
"//tensorflow/contrib/boosted_trees:gbdt_batch",
|
||||
"//tensorflow/contrib/boosted_trees:model_ops_py",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:training_util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "trainer_hooks",
|
||||
srcs = ["trainer_hooks.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "estimator_utils",
|
||||
srcs = ["estimator_utils.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "trainer_hooks_test",
|
||||
size = "small",
|
||||
srcs = ["trainer_hooks_test.py"],
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
# TODO(b/134605801): Re-enable this test.
|
||||
"no_oss",
|
||||
],
|
||||
deps = [
|
||||
":trainer_hooks",
|
||||
"//tensorflow/contrib/framework:framework_py",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "custom_loss_head",
|
||||
srcs = ["custom_loss_head.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:math_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "custom_export_strategy",
|
||||
srcs = ["custom_export_strategy.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/boosted_trees:gbdt_batch",
|
||||
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py",
|
||||
"//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_py",
|
||||
"//tensorflow/contrib/decision_trees/proto:generic_tree_model_py",
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python/saved_model:loader",
|
||||
"//tensorflow/python/saved_model:tag_constants",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "custom_export_strategy_test",
|
||||
size = "small",
|
||||
srcs = ["custom_export_strategy_test.py"],
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":custom_export_strategy",
|
||||
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "estimator",
|
||||
srcs = ["estimator.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":custom_loss_head",
|
||||
":estimator_utils",
|
||||
":model",
|
||||
"//tensorflow/contrib/boosted_trees:losses",
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/python:math_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "dnn_tree_combined_estimator",
|
||||
srcs = ["dnn_tree_combined_estimator.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":distillation_loss",
|
||||
":estimator_utils",
|
||||
":model",
|
||||
":trainer_hooks",
|
||||
"//tensorflow/contrib/boosted_trees:gbdt_batch",
|
||||
"//tensorflow/contrib/boosted_trees:model_ops_py",
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "distillation_loss",
|
||||
srcs = ["distillation_loss.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nn",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "dnn_tree_combined_estimator_test",
|
||||
size = "medium",
|
||||
timeout = "long",
|
||||
srcs = ["dnn_tree_combined_estimator_test.py"],
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_gpu",
|
||||
"no_pip_gpu",
|
||||
"notsan",
|
||||
],
|
||||
deps = [
|
||||
":dnn_tree_combined_estimator",
|
||||
"//tensorflow/contrib/boosted_trees:gbdt_batch",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "estimator_test",
|
||||
size = "medium",
|
||||
srcs = ["estimator_test.py"],
|
||||
python_version = "PY2",
|
||||
shard_count = 4,
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_gpu",
|
||||
"no_pip_gpu",
|
||||
"notsan",
|
||||
],
|
||||
deps = [
|
||||
":estimator",
|
||||
"//tensorflow/contrib/boosted_trees:gbdt_batch",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
],
|
||||
)
|
@ -1,22 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Gradient boosted trees implementation in tensorflow."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
from tensorflow.contrib.boosted_trees.estimator_batch import *
|
||||
# pylint: enable=unused-import,wildcard-import
|
@ -1,267 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Strategy to export custom proto formats."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
|
||||
from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
|
||||
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
|
||||
from tensorflow.contrib.decision_trees.proto import generic_tree_model_extensions_pb2
|
||||
from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2
|
||||
from tensorflow.contrib.learn.python.learn import export_strategy
|
||||
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
|
||||
from tensorflow.python.client import session as tf_session
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.saved_model import loader as saved_model_loader
|
||||
from tensorflow.python.saved_model import tag_constants
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
_SPARSE_FLOAT_FEATURE_NAME_TEMPLATE = "%s_%d"
|
||||
|
||||
|
||||
def make_custom_export_strategy(name,
|
||||
convert_fn,
|
||||
feature_columns,
|
||||
export_input_fn,
|
||||
use_core_columns=False,
|
||||
feature_engineering_fn=None,
|
||||
default_output_alternative_key=None):
|
||||
"""Makes custom exporter of GTFlow tree format.
|
||||
|
||||
Args:
|
||||
name: A string, for the name of the export strategy.
|
||||
    convert_fn: A function that converts the tree proto to the desired format
      and saves it to the desired location. Can be None to skip conversion.
|
||||
feature_columns: A list of feature columns.
|
||||
export_input_fn: A function that takes no arguments and returns an
|
||||
`InputFnOps`.
|
||||
use_core_columns: A boolean, whether core feature columns were used.
|
||||
    feature_engineering_fn: Feature engineering function to be called on the
      input.
|
||||
default_output_alternative_key: the name of the head to serve when an
|
||||
incoming serving request does not explicitly request a specific head.
|
||||
Not needed for single-headed models.
|
||||
|
||||
Returns:
|
||||
An `ExportStrategy`.
|
||||
"""
|
||||
base_strategy = saved_model_export_utils.make_export_strategy(
|
||||
serving_input_fn=export_input_fn,
|
||||
strip_default_attrs=True,
|
||||
default_output_alternative_key=default_output_alternative_key)
|
||||
input_fn = export_input_fn()
|
||||
features = input_fn.features
|
||||
if feature_engineering_fn is not None:
|
||||
features, _ = feature_engineering_fn(features, labels=None)
|
||||
(sorted_feature_names, dense_floats, sparse_float_indices, _, _,
|
||||
sparse_int_indices, _, _) = gbdt_batch.extract_features(
|
||||
features, feature_columns, use_core_columns)
|
||||
|
||||
def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None):
|
||||
"""A wrapper to export to SavedModel, and convert it to other formats."""
|
||||
result_dir = base_strategy.export(estimator, export_dir,
|
||||
checkpoint_path,
|
||||
eval_result)
|
||||
with ops.Graph().as_default() as graph:
|
||||
with tf_session.Session(graph=graph) as sess:
|
||||
saved_model_loader.load(
|
||||
sess, [tag_constants.SERVING], result_dir)
|
||||
# Note: This is GTFlow internal API and might change.
|
||||
ensemble_model = graph.get_operation_by_name(
|
||||
"ensemble_model/TreeEnsembleSerialize")
|
||||
_, dfec_str = sess.run(ensemble_model.outputs)
|
||||
dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
|
||||
dtec.ParseFromString(dfec_str)
|
||||
# Export the result in the same folder as the saved model.
|
||||
if convert_fn:
|
||||
convert_fn(dtec, sorted_feature_names,
|
||||
len(dense_floats),
|
||||
len(sparse_float_indices),
|
||||
len(sparse_int_indices), result_dir, eval_result)
|
||||
feature_importances = _get_feature_importances(
|
||||
dtec, sorted_feature_names,
|
||||
len(dense_floats),
|
||||
len(sparse_float_indices), len(sparse_int_indices))
|
||||
sorted_by_importance = sorted(
|
||||
feature_importances.items(), key=lambda x: -x[1])
|
||||
assets_dir = os.path.join(
|
||||
compat.as_bytes(result_dir), compat.as_bytes("assets.extra"))
|
||||
gfile.MakeDirs(assets_dir)
|
||||
with gfile.GFile(os.path.join(
|
||||
compat.as_bytes(assets_dir),
|
||||
compat.as_bytes("feature_importances")), "w") as f:
|
||||
f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance))
|
||||
return result_dir
|
||||
|
||||
return export_strategy.ExportStrategy(
|
||||
name, export_fn, strip_default_attrs=True)
|
||||
|
||||
|
||||
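A hedged sketch of how the export strategy above is typically wired up;
`my_feature_columns` and `my_serving_input_fn` are hypothetical placeholders:

```
strategy = make_custom_export_strategy(
    name="custom_export",
    convert_fn=None,  # None skips the proto conversion step
    feature_columns=my_feature_columns,
    export_input_fn=my_serving_input_fn)
# The returned ExportStrategy can then be passed to, e.g., a
# tf.contrib.learn Experiment via its export_strategies argument.
```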
def convert_to_universal_format(dtec, sorted_feature_names,
|
||||
num_dense, num_sparse_float,
|
||||
num_sparse_int,
|
||||
feature_name_to_proto=None):
|
||||
"""Convert GTFlow trees to universal format."""
|
||||
del num_sparse_int # unused.
|
||||
model_and_features = generic_tree_model_pb2.ModelAndFeatures()
|
||||
# TODO(jonasz): Feature descriptions should contain information about how each
|
||||
# feature is processed before it's fed to the model (e.g. bucketing
|
||||
# information). As of now, this serves as a list of features the model uses.
|
||||
for feature_name in sorted_feature_names:
|
||||
if not feature_name_to_proto:
|
||||
model_and_features.features[feature_name].SetInParent()
|
||||
else:
|
||||
model_and_features.features[feature_name].CopyFrom(
|
||||
feature_name_to_proto[feature_name])
|
||||
model = model_and_features.model
|
||||
model.ensemble.summation_combination_technique.SetInParent()
|
||||
for tree_idx in range(len(dtec.trees)):
|
||||
gtflow_tree = dtec.trees[tree_idx]
|
||||
tree_weight = dtec.tree_weights[tree_idx]
|
||||
member = model.ensemble.members.add()
|
||||
member.submodel_id.value = tree_idx
|
||||
tree = member.submodel.decision_tree
|
||||
for node_idx in range(len(gtflow_tree.nodes)):
|
||||
gtflow_node = gtflow_tree.nodes[node_idx]
|
||||
node = tree.nodes.add()
|
||||
node_type = gtflow_node.WhichOneof("node")
|
||||
node.node_id.value = node_idx
|
||||
if node_type == "leaf":
|
||||
leaf = gtflow_node.leaf
|
||||
if leaf.HasField("vector"):
|
||||
for weight in leaf.vector.value:
|
||||
new_value = node.leaf.vector.value.add()
|
||||
new_value.float_value = weight * tree_weight
|
||||
else:
|
||||
for index, weight in zip(
|
||||
leaf.sparse_vector.index, leaf.sparse_vector.value):
|
||||
new_value = node.leaf.sparse_vector.sparse_value[index]
|
||||
new_value.float_value = weight * tree_weight
|
||||
else:
|
||||
node = node.binary_node
|
||||
# Binary nodes here.
|
||||
if node_type == "dense_float_binary_split":
|
||||
split = gtflow_node.dense_float_binary_split
|
||||
feature_id = split.feature_column
|
||||
inequality_test = node.inequality_left_child_test
|
||||
inequality_test.feature_id.id.value = sorted_feature_names[feature_id]
|
||||
inequality_test.type = (
|
||||
generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
|
||||
inequality_test.threshold.float_value = split.threshold
|
||||
elif node_type == "sparse_float_binary_split_default_left":
|
||||
split = gtflow_node.sparse_float_binary_split_default_left.split
|
||||
node.default_direction = (generic_tree_model_pb2.BinaryNode.LEFT)
|
||||
feature_id = split.feature_column + num_dense
|
||||
inequality_test = node.inequality_left_child_test
|
||||
inequality_test.feature_id.id.value = (
|
||||
_SPARSE_FLOAT_FEATURE_NAME_TEMPLATE %
|
||||
(sorted_feature_names[feature_id], split.dimension_id))
|
||||
model_and_features.features.pop(sorted_feature_names[feature_id])
|
||||
(model_and_features.features[inequality_test.feature_id.id.value]
|
||||
.SetInParent())
|
||||
inequality_test.type = (
|
||||
generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
|
||||
inequality_test.threshold.float_value = split.threshold
|
||||
elif node_type == "sparse_float_binary_split_default_right":
|
||||
split = gtflow_node.sparse_float_binary_split_default_right.split
|
||||
node.default_direction = (
|
||||
generic_tree_model_pb2.BinaryNode.RIGHT)
|
||||
        # TODO(nponomareva): adjust this id assignment when we allow multi-
        # column sparse tensors.
|
||||
feature_id = split.feature_column + num_dense
|
||||
inequality_test = node.inequality_left_child_test
|
||||
inequality_test.feature_id.id.value = (
|
||||
_SPARSE_FLOAT_FEATURE_NAME_TEMPLATE %
|
||||
(sorted_feature_names[feature_id], split.dimension_id))
|
||||
model_and_features.features.pop(sorted_feature_names[feature_id])
|
||||
(model_and_features.features[inequality_test.feature_id.id.value]
|
||||
.SetInParent())
|
||||
inequality_test.type = (
|
||||
generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
|
||||
inequality_test.threshold.float_value = split.threshold
|
||||
elif node_type == "categorical_id_binary_split":
|
||||
split = gtflow_node.categorical_id_binary_split
|
||||
node.default_direction = generic_tree_model_pb2.BinaryNode.RIGHT
|
||||
feature_id = split.feature_column + num_dense + num_sparse_float
|
||||
categorical_test = (
|
||||
generic_tree_model_extensions_pb2.MatchingValuesTest())
|
||||
categorical_test.feature_id.id.value = sorted_feature_names[
|
||||
feature_id]
|
||||
matching_id = categorical_test.value.add()
|
||||
matching_id.int64_value = split.feature_id
|
||||
node.custom_left_child_test.Pack(categorical_test)
|
||||
elif (node_type == "oblivious_dense_float_binary_split" or
|
||||
node_type == "oblivious_categorical_id_binary_split"):
|
||||
raise ValueError("Universal tree format doesn't support oblivious "
|
||||
"trees")
|
||||
else:
|
||||
raise ValueError("Unexpected node type %s" % node_type)
|
||||
node.left_child_id.value = split.left_id
|
||||
node.right_child_id.value = split.right_id
|
||||
return model_and_features
|
||||
|
||||
|
||||
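A one-line worked example of the leaf scaling performed above, matching the
0.6-valued leaf and the 0.1 tree weight used in the conversion test later in
this change:

```
exported_value = 0.6 * 0.1   # == 0.06; leaf values are pre-multiplied by the
                             # tree weight during conversion
```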
def _get_feature_importances(dtec, feature_names, num_dense_floats,
|
||||
num_sparse_float, num_sparse_int):
|
||||
"""Export the feature importance per feature column."""
|
||||
del num_sparse_int # Unused.
|
||||
sums = collections.defaultdict(lambda: 0)
|
||||
for tree_idx in range(len(dtec.trees)):
|
||||
tree = dtec.trees[tree_idx]
|
||||
for tree_node in tree.nodes:
|
||||
node_type = tree_node.WhichOneof("node")
|
||||
if node_type == "dense_float_binary_split":
|
||||
split = tree_node.dense_float_binary_split
|
||||
split_column = feature_names[split.feature_column]
|
||||
elif node_type == "sparse_float_binary_split_default_left":
|
||||
split = tree_node.sparse_float_binary_split_default_left.split
|
||||
split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % (
|
||||
feature_names[split.feature_column + num_dense_floats],
|
||||
split.dimension_id)
|
||||
elif node_type == "sparse_float_binary_split_default_right":
|
||||
split = tree_node.sparse_float_binary_split_default_right.split
|
||||
split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % (
|
||||
feature_names[split.feature_column + num_dense_floats],
|
||||
split.dimension_id)
|
||||
elif node_type == "categorical_id_binary_split":
|
||||
split = tree_node.categorical_id_binary_split
|
||||
split_column = feature_names[split.feature_column + num_dense_floats +
|
||||
num_sparse_float]
|
||||
elif node_type == "oblivious_dense_float_binary_split":
|
||||
split = tree_node.oblivious_dense_float_binary_split
|
||||
split_column = feature_names[split.feature_column]
|
||||
elif node_type == "oblivious_categorical_id_binary_split":
|
||||
split = tree_node.oblivious_categorical_id_binary_split
|
||||
split_column = feature_names[split.feature_column + num_dense_floats +
|
||||
num_sparse_float]
|
||||
elif node_type == "categorical_id_set_membership_binary_split":
|
||||
split = tree_node.categorical_id_set_membership_binary_split
|
||||
split_column = feature_names[split.feature_column + num_dense_floats +
|
||||
num_sparse_float]
|
||||
elif node_type == "leaf":
|
||||
assert tree_node.node_metadata.gain == 0
|
||||
continue
|
||||
else:
|
||||
raise ValueError("Unexpected split type %s" % node_type)
|
||||
# Apply shrinkage factor. It is important since it is not always uniform
|
||||
# across different trees.
|
||||
sums[split_column] += (
|
||||
tree_node.node_metadata.gain * dtec.tree_weights[tree_idx])
|
||||
return dict(sums)
|
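A small worked example of the gain aggregation above, using hypothetical
numbers: two trees with tree weights 1.0 and 0.1 each split once on the same
column, with gains 500 and 3600 respectively.

```
importance = 500 * 1.0 + 3600 * 0.1   # == 860.0 accumulated for that column
```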
@ -1,364 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the conversion code and for feature importances export.
|
||||
|
||||
Tests that cover conversion from TFBT format to a tensorflow.contrib.
|
||||
decision_tree generic_tree_model format and feature importances export.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from google.protobuf import text_format
|
||||
from tensorflow.contrib.boosted_trees.estimator_batch import custom_export_strategy
|
||||
from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class ConvertModelTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def _make_trees(self):
|
||||
dtec_str = """
|
||||
trees {
|
||||
nodes {
|
||||
leaf {
|
||||
vector {
|
||||
value: -1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
trees {
|
||||
nodes {
|
||||
dense_float_binary_split {
|
||||
feature_column: 0
|
||||
threshold: 1740.0
|
||||
left_id: 1
|
||||
right_id: 2
|
||||
}
|
||||
node_metadata {
|
||||
gain: 500
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
leaf {
|
||||
vector {
|
||||
value: 0.6
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
sparse_float_binary_split_default_left {
|
||||
split {
|
||||
feature_column: 0
|
||||
threshold: 1500.0
|
||||
left_id: 3
|
||||
right_id: 4
|
||||
}
|
||||
}
|
||||
node_metadata {
|
||||
gain: 500
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
categorical_id_binary_split {
|
||||
feature_column: 0
|
||||
feature_id: 5
|
||||
left_id: 5
|
||||
right_id: 6
|
||||
}
|
||||
node_metadata {
|
||||
gain: 500
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
leaf {
|
||||
vector {
|
||||
value: 0.8
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
leaf {
|
||||
vector {
|
||||
value: 0.5
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
sparse_float_binary_split_default_right {
|
||||
split {
|
||||
feature_column: 1
|
||||
dimension_id:3
|
||||
threshold: -0.4
|
||||
left_id: 7
|
||||
right_id: 8
|
||||
}
|
||||
}
|
||||
node_metadata {
|
||||
gain: 3600
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
leaf {
|
||||
vector {
|
||||
value: 0.36
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
leaf {
|
||||
vector {
|
||||
value: 18
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tree_weights: 1.0
|
||||
tree_weights: 0.1
|
||||
"""
|
||||
dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
|
||||
text_format.Merge(dtec_str, dtec)
|
||||
feature_columns = [
|
||||
"feature_b",
|
||||
"feature_a",
|
||||
"feature_a_m",
|
||||
"feature_d",
|
||||
]
|
||||
return dtec, feature_columns
|
||||
|
||||
def testConvertModel(self):
|
||||
dtec, feature_columns = self._make_trees()
|
||||
# Assume 2 sparse float columns, one with 1 dimension, the second one with
|
||||
# 5 dimensions.
|
||||
# The feature columns in the order they were added.
|
||||
out = custom_export_strategy.convert_to_universal_format(
|
||||
dtec, feature_columns, 1, 2, 1)
|
||||
# Features a and a_m are sparse float features, a_m is multidimensional.
|
||||
expected_tree = """
|
||||
features { key: "feature_a_0" }
|
||||
features { key: "feature_a_m_3" }
|
||||
features { key: "feature_b" }
|
||||
features { key: "feature_d" }
|
||||
model {
|
||||
ensemble {
|
||||
summation_combination_technique {
|
||||
}
|
||||
members {
|
||||
submodel {
|
||||
decision_tree {
|
||||
nodes {
|
||||
node_id {
|
||||
}
|
||||
leaf {
|
||||
vector {
|
||||
value {
|
||||
float_value: -1.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
submodel_id {
|
||||
}
|
||||
}
|
||||
members {
|
||||
submodel {
|
||||
decision_tree {
|
||||
nodes {
|
||||
node_id {
|
||||
}
|
||||
binary_node {
|
||||
left_child_id {
|
||||
value: 1
|
||||
}
|
||||
right_child_id {
|
||||
value: 2
|
||||
}
|
||||
inequality_left_child_test {
|
||||
feature_id {
|
||||
id {
|
||||
value: "feature_b"
|
||||
}
|
||||
}
|
||||
threshold {
|
||||
float_value: 1740.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
node_id {
|
||||
value: 1
|
||||
}
|
||||
leaf {
|
||||
vector {
|
||||
value {
|
||||
float_value: 0.06
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
node_id {
|
||||
value: 2
|
||||
}
|
||||
binary_node {
|
||||
left_child_id {
|
||||
value: 3
|
||||
}
|
||||
right_child_id {
|
||||
value: 4
|
||||
}
|
||||
inequality_left_child_test {
|
||||
feature_id {
|
||||
id {
|
||||
value: "feature_a_0"
|
||||
}
|
||||
}
|
||||
threshold {
|
||||
float_value: 1500.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
node_id {
|
||||
value: 3
|
||||
}
|
||||
binary_node {
|
||||
left_child_id {
|
||||
value: 5
|
||||
}
|
||||
right_child_id {
|
||||
value: 6
|
||||
}
|
||||
default_direction: RIGHT
|
||||
custom_left_child_test {
|
||||
[type.googleapis.com/tensorflow.decision_trees.MatchingValuesTest] {
|
||||
feature_id {
|
||||
id {
|
||||
value: "feature_d"
|
||||
}
|
||||
}
|
||||
value {
|
||||
int64_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
node_id {
|
||||
value: 4
|
||||
}
|
||||
leaf {
|
||||
vector {
|
||||
value {
|
||||
float_value: 0.08
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
node_id {
|
||||
value: 5
|
||||
}
|
||||
leaf {
|
||||
vector {
|
||||
value {
|
||||
float_value: 0.05
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
node_id {
|
||||
value: 6
|
||||
}
|
||||
binary_node {
|
||||
left_child_id {
|
||||
value: 7
|
||||
}
|
||||
right_child_id {
|
||||
value: 8
|
||||
}
|
||||
default_direction: RIGHT
|
||||
inequality_left_child_test {
|
||||
feature_id {
|
||||
id {
|
||||
value: "feature_a_m_3"
|
||||
}
|
||||
}
|
||||
threshold {
|
||||
float_value: -0.4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
node_id {
|
||||
value: 7
|
||||
}
|
||||
leaf {
|
||||
vector {
|
||||
value {
|
||||
float_value: 0.036
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
node_id {
|
||||
value: 8
|
||||
}
|
||||
leaf {
|
||||
vector {
|
||||
value {
|
||||
float_value: 1.8
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
submodel_id {
|
||||
value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}"""
|
||||
self.assertProtoEquals(expected_tree, out)
|
||||
|
||||
def testFeatureImportance(self):
|
||||
dtec, feature_columns = self._make_trees()
|
||||
feature_importances = custom_export_strategy._get_feature_importances(
|
||||
dtec, feature_columns, 1, 2, 1)
|
||||
self.assertItemsEqual(
|
||||
["feature_b", "feature_a_0", "feature_a_m_3", "feature_d"],
|
||||
feature_importances.keys())
|
||||
self.assertAlmostEqual(50.0, feature_importances["feature_b"], places=4)
|
||||
self.assertAlmostEqual(50.0, feature_importances["feature_a_0"], places=4)
|
||||
self.assertAlmostEqual(50.0, feature_importances["feature_d"], places=4)
|
||||
self.assertAlmostEqual(
|
||||
360.0, feature_importances["feature_a_m_3"], places=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
@ -1,73 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Implementation of `head.Head` with custom loss and link function."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
|
||||
class CustomLossHead(head_lib._RegressionHead): # pylint: disable=protected-access
|
||||
"""A Head object with custom loss function and link function."""
|
||||
|
||||
def __init__(self,
|
||||
loss_fn,
|
||||
link_fn,
|
||||
logit_dimension,
|
||||
head_name=None,
|
||||
weight_column_name=None,
|
||||
metrics_fn=None):
|
||||
"""`Head` for specifying arbitrary loss function.
|
||||
|
||||
Args:
|
||||
loss_fn: Loss function.
|
||||
link_fn: Function that converts logits to prediction.
|
||||
logit_dimension: Number of dimensions for the logits.
|
||||
head_name: name of the head. Predictions, summary, metrics keys are
|
||||
suffixed by `"/" + head_name` and the default variable scope is
|
||||
`head_name`.
|
||||
weight_column_name: A string defining feature column name representing
|
||||
weights. It is used to down weight or boost examples during training. It
|
||||
will be multiplied by the loss of the example.
|
||||
metrics_fn: a function that takes predictions dict, labels and weights and
|
||||
returns a dictionary of metrics to be calculated.
|
||||
"""
|
||||
|
||||
    def loss_wrapper(labels, logits, weight_tensor):
      if weight_tensor is None:
        weight_tensor = array_ops.ones(
            shape=[array_ops.shape(labels)[0], 1], dtype=dtypes.float32)
      weighted_loss, _ = loss_fn(labels, weight_tensor, logits)
      average_loss = math_ops.reduce_mean(weighted_loss)
      return average_loss, average_loss / math_ops.reduce_mean(weight_tensor)

    super(CustomLossHead, self).__init__(
        loss_fn=loss_wrapper,
        link_fn=link_fn,
        head_name=head_name,
        weight_column_name=weight_column_name,
        enable_centered_bias=False,
        label_dimension=logit_dimension)

    self._metrics_fn = metrics_fn

  def _metrics(self, eval_loss, predictions, labels, weights):
    if self._metrics_fn is not None:
      return self._metrics_fn(predictions, labels, weights)
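# --- Editorial sketch (not part of the original diff) ---
# A minimal example of wiring a custom loss and link function into
# `CustomLossHead`. The loss_fn must accept (labels, weights, logits) and
# return a tuple whose first element is the weighted per-example loss, which
# is what `loss_wrapper` above consumes. `my_squared_loss`, `identity_link`
# and `_example_custom_loss_head` are hypothetical names used for illustration.


def my_squared_loss(labels, weights, logits):
  weighted_loss = weights * math_ops.squared_difference(
      math_ops.cast(labels, dtypes.float32), logits)
  return weighted_loss, weighted_loss


def identity_link(logits):
  return logits


def _example_custom_loss_head():
  """Builds a CustomLossHead wired to the hypothetical functions above."""
  return CustomLossHead(
      loss_fn=my_squared_loss,
      link_fn=identity_link,
      logit_dimension=1,
      head_name="custom_loss_example")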
|
@ -1,75 +0,0 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utill functions for distillation loss.
|
||||
|
||||
The distillation loss_fn will be called with the following:
|
||||
|
||||
Args:
|
||||
dnn_logits: Tensor of logits from the dnn, treated as the "target". This will
|
||||
be the output of a call to tf.stop_gradient().
|
||||
tree_logits: Tensor of logits from the tree, treated as the "predictions".
|
||||
example_weights: Tensor of example weights, or a single scalar.
|
||||
|
||||
Returns:
|
||||
A scalar indicating the reduced loss for that batch of examples.
|
||||
|
||||
Note: we calls the loss_fn defined in contrib head, which is computing two
|
||||
losses, first one for training and second one for reporting. We only take the
|
||||
first one here.
|
||||
"""
|
||||
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn


def _logits_to_label_for_tree(logits, n_classes):
  if n_classes == 2:
    return math_ops.sigmoid(logits)
  else:
    return nn.softmax(logits)


def create_dnn_to_tree_squared_loss_fn(n_classes):
  """Returns a squared loss function for dnn to tree distillation."""

  def _dnn_to_tree_squared_loss(dnn_logits, tree_logits, example_weights):
    return head_lib._mean_squared_loss(  # pylint: disable=protected-access
        labels=_logits_to_label_for_tree(dnn_logits, n_classes),
        logits=_logits_to_label_for_tree(tree_logits, n_classes),
        weights=example_weights)[0]

  return _dnn_to_tree_squared_loss


def create_dnn_to_tree_cross_entropy_loss_fn(n_classes):
  """Returns a cross entropy loss function for dnn to tree distillation."""

  def _dnn_to_tree_cross_entropy_loss(dnn_logits, tree_logits, example_weights):
    if n_classes == 2:
      return head_lib._log_loss_with_two_classes(  # pylint: disable=protected-access
          labels=_logits_to_label_for_tree(dnn_logits, n_classes),
          logits=tree_logits,
          weights=example_weights)[0]
    else:
      return head_lib._softmax_cross_entropy_loss(  # pylint: disable=protected-access
          labels=_logits_to_label_for_tree(dnn_logits, n_classes),
          logits=tree_logits,
          weights=example_weights)[0]

  return _dnn_to_tree_cross_entropy_loss
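# --- Editorial sketch (not part of the original diff) ---
# The module docstring above describes the distillation loss_fn contract: it
# receives (dnn_logits, tree_logits, example_weights) and returns one reduced
# scalar. A hand-rolled variant that skips the contrib head entirely could look
# like the sketch below; `create_l2_distillation_loss_fn` is a hypothetical
# helper, not part of the original API.


def create_l2_distillation_loss_fn(n_classes):
  """Returns a plain weighted L2 distillation loss over tree probabilities."""

  def _l2_distillation_loss(dnn_logits, tree_logits, example_weights):
    target = _logits_to_label_for_tree(dnn_logits, n_classes)
    prediction = _logits_to_label_for_tree(tree_logits, n_classes)
    per_example = math_ops.reduce_sum(
        math_ops.squared_difference(target, prediction),
        axis=-1,
        keepdims=True)
    return math_ops.reduce_mean(example_weights * per_example)

  return _l2_distillation_loss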
|
@ -1,859 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""TensorFlow estimators for combined DNN + GBDT training model.
|
||||
|
||||
The combined model trains a DNN first, then trains boosted trees to boost the
|
||||
logits of the DNN. The input layer of the DNN (including the embeddings learned
|
||||
over sparse features) can optionally be provided to the boosted trees as
|
||||
an additional input feature.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.contrib import layers
|
||||
from tensorflow.contrib.boosted_trees.estimator_batch import model
|
||||
from tensorflow.contrib.boosted_trees.estimator_batch import distillation_loss
|
||||
from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks
|
||||
from tensorflow.contrib.boosted_trees.python.ops import model_ops
|
||||
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
|
||||
from tensorflow.contrib.layers.python.layers import optimizers
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
||||
from tensorflow.python.estimator import estimator as core_estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import model_fn
|
||||
from tensorflow.python.feature_column import feature_column_lib
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import partitioned_variables
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.summary import summary
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
_DNN_LEARNING_RATE = 0.001
|
||||
|
||||
|
||||
def _get_optimizer(optimizer):
|
||||
if callable(optimizer):
|
||||
return optimizer()
|
||||
else:
|
||||
return optimizer
|
||||
|
||||
|
||||
def _add_hidden_layer_summary(value, tag):
|
||||
summary.scalar("%s_fraction_of_zero_values" % tag, nn.zero_fraction(value))
|
||||
summary.histogram("%s_activation" % tag, value)
|
||||
|
||||
|
||||
def _dnn_tree_combined_model_fn(
|
||||
features,
|
||||
labels,
|
||||
mode,
|
||||
head,
|
||||
dnn_hidden_units,
|
||||
dnn_feature_columns,
|
||||
tree_learner_config,
|
||||
num_trees,
|
||||
tree_examples_per_layer,
|
||||
config=None,
|
||||
dnn_optimizer="Adagrad",
|
||||
dnn_activation_fn=nn.relu,
|
||||
dnn_dropout=None,
|
||||
dnn_input_layer_partitioner=None,
|
||||
dnn_input_layer_to_tree=True,
|
||||
dnn_steps_to_train=10000,
|
||||
predict_with_tree_only=False,
|
||||
tree_feature_columns=None,
|
||||
tree_center_bias=False,
|
||||
dnn_to_tree_distillation_param=None,
|
||||
use_core_versions=False,
|
||||
output_type=model.ModelBuilderOutputType.MODEL_FN_OPS,
|
||||
override_global_step_value=None):
|
||||
"""DNN and GBDT combined model_fn.
|
||||
|
||||
Args:
|
||||
features: `dict` of `Tensor` objects.
|
||||
labels: Labels used to train on.
|
||||
mode: Mode we are in. (TRAIN/EVAL/INFER)
|
||||
head: A `Head` instance.
|
||||
dnn_hidden_units: List of hidden units per layer.
|
||||
dnn_feature_columns: An iterable containing all the feature columns
|
||||
used by the model's DNN.
|
||||
tree_learner_config: A config for the tree learner.
|
||||
num_trees: Number of trees to grow model to after training DNN.
|
||||
tree_examples_per_layer: Number of examples to accumulate before
|
||||
growing the tree a layer. This value has a big impact on model
|
||||
quality and should be set equal to the number of examples in
|
||||
training dataset if possible. It can also be a function that computes
|
||||
the number of examples based on the depth of the layer that's
|
||||
being built.
|
||||
config: `RunConfig` of the estimator.
|
||||
dnn_optimizer: string, `Optimizer` object, or callable that defines the
|
||||
optimizer to use for training the DNN. If `None`, will use the Adagrad
|
||||
optimizer with default learning rate of 0.001.
|
||||
dnn_activation_fn: Activation function applied to each layer of the DNN.
|
||||
If `None`, will use `tf.nn.relu`.
|
||||
dnn_dropout: When not `None`, the probability to drop out a given
|
||||
unit in the DNN.
|
||||
dnn_input_layer_partitioner: Partitioner for input layer of the DNN.
|
||||
Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
|
||||
dnn_input_layer_to_tree: Whether to provide the DNN's input layer
|
||||
as a feature to the tree.
|
||||
dnn_steps_to_train: Number of steps to train dnn for before switching
|
||||
to gbdt.
|
||||
predict_with_tree_only: Whether to use only the tree model output as the
|
||||
final prediction.
|
||||
tree_feature_columns: An iterable containing all the feature columns
|
||||
used by the model's boosted trees. If dnn_input_layer_to_tree is
|
||||
set to True, these features are in addition to dnn_feature_columns.
|
||||
tree_center_bias: Whether a separate tree should be created for
|
||||
first fitting the bias.
|
||||
dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the
|
||||
float defines the weight of the distillation loss, and the loss_fn, for
|
||||
computing distillation loss, takes dnn_logits, tree_logits and weight
|
||||
tensor. If the entire tuple is None, no distillation will be applied. If
|
||||
only the loss_fn is None, we will take the sigmoid/softmax cross entropy
|
||||
loss by default. When distillation is applied, `predict_with_tree_only`
|
||||
will be set to True.
|
||||
use_core_versions: Whether feature columns and loss are from the core (as
|
||||
opposed to contrib) version of tensorflow.
|
||||
output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
|
||||
(new interface).
|
||||
override_global_step_value: If after the training is done, global step
|
||||
value must be reset to this value. This is particularly useful for hyper
|
||||
parameter tuning, which can't recognize early stopping due to the number
|
||||
of trees. If None, no override of global step will happen.
|
||||
|
||||
Returns:
|
||||
A `ModelFnOps` object.
|
||||
Raises:
|
||||
ValueError: if inputs are not valid.
|
||||
"""
|
||||
if not isinstance(features, dict):
|
||||
raise ValueError("features should be a dictionary of `Tensor`s. "
|
||||
"Given type: {}".format(type(features)))
|
||||
|
||||
if not dnn_feature_columns:
|
||||
raise ValueError("dnn_feature_columns must be specified")
|
||||
|
||||
if dnn_to_tree_distillation_param:
|
||||
if not predict_with_tree_only:
|
||||
logging.warning("update predict_with_tree_only to True since distillation"
|
||||
"is specified.")
|
||||
predict_with_tree_only = True
|
||||
|
||||
# Build DNN Logits.
|
||||
dnn_parent_scope = "dnn"
|
||||
dnn_partitioner = dnn_input_layer_partitioner or (
|
||||
partitioned_variables.min_max_variable_partitioner(
|
||||
max_partitions=config.num_ps_replicas, min_slice_size=64 << 20))
|
||||
|
||||
if (output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC and
|
||||
not use_core_versions):
|
||||
raise ValueError("You must use core versions with Estimator Spec")
|
||||
global_step = training_util.get_global_step()
|
||||
|
||||
with variable_scope.variable_scope(
|
||||
dnn_parent_scope,
|
||||
values=tuple(six.itervalues(features)),
|
||||
partitioner=dnn_partitioner):
|
||||
|
||||
with variable_scope.variable_scope(
|
||||
"input_from_feature_columns",
|
||||
values=tuple(six.itervalues(features)),
|
||||
partitioner=dnn_partitioner) as input_layer_scope:
|
||||
if use_core_versions:
|
||||
input_layer = feature_column_lib.input_layer(
|
||||
features=features,
|
||||
feature_columns=dnn_feature_columns,
|
||||
weight_collections=[dnn_parent_scope])
|
||||
else:
|
||||
input_layer = layers.input_from_feature_columns(
|
||||
columns_to_tensors=features,
|
||||
feature_columns=dnn_feature_columns,
|
||||
weight_collections=[dnn_parent_scope],
|
||||
scope=input_layer_scope)
|
||||
def dnn_logits_fn():
|
||||
"""Builds the logits from the input layer."""
|
||||
previous_layer = input_layer
|
||||
for layer_id, num_hidden_units in enumerate(dnn_hidden_units):
|
||||
with variable_scope.variable_scope(
|
||||
"hiddenlayer_%d" % layer_id,
|
||||
values=(previous_layer,)) as hidden_layer_scope:
|
||||
net = layers.fully_connected(
|
||||
previous_layer,
|
||||
num_hidden_units,
|
||||
activation_fn=dnn_activation_fn,
|
||||
variables_collections=[dnn_parent_scope],
|
||||
scope=hidden_layer_scope)
|
||||
if dnn_dropout is not None and mode == model_fn.ModeKeys.TRAIN:
|
||||
net = layers.dropout(net, keep_prob=(1.0 - dnn_dropout))
|
||||
_add_hidden_layer_summary(net, hidden_layer_scope.name)
|
||||
previous_layer = net
|
||||
with variable_scope.variable_scope(
|
||||
"logits", values=(previous_layer,)) as logits_scope:
|
||||
dnn_logits = layers.fully_connected(
|
||||
previous_layer,
|
||||
head.logits_dimension,
|
||||
activation_fn=None,
|
||||
variables_collections=[dnn_parent_scope],
|
||||
scope=logits_scope)
|
||||
_add_hidden_layer_summary(dnn_logits, logits_scope.name)
|
||||
return dnn_logits
|
||||
if predict_with_tree_only and mode == model_fn.ModeKeys.INFER:
|
||||
dnn_logits = array_ops.constant(0.0)
|
||||
dnn_train_op_fn = control_flow_ops.no_op
|
||||
elif predict_with_tree_only and mode == model_fn.ModeKeys.EVAL:
|
||||
dnn_logits = control_flow_ops.cond(
|
||||
global_step > dnn_steps_to_train,
|
||||
lambda: array_ops.constant(0.0),
|
||||
dnn_logits_fn)
|
||||
dnn_train_op_fn = control_flow_ops.no_op
|
||||
else:
|
||||
dnn_logits = dnn_logits_fn()
|
||||
def dnn_train_op_fn(loss):
|
||||
"""Returns the op to optimize the loss."""
|
||||
return optimizers.optimize_loss(
|
||||
loss=loss,
|
||||
global_step=training_util.get_global_step(),
|
||||
learning_rate=_DNN_LEARNING_RATE,
|
||||
optimizer=_get_optimizer(dnn_optimizer),
|
||||
name=dnn_parent_scope,
|
||||
variables=ops.get_collection(
|
||||
ops.GraphKeys.TRAINABLE_VARIABLES, scope=dnn_parent_scope),
|
||||
# Empty summaries to prevent optimizers from logging training_loss.
|
||||
summaries=[])
|
||||
|
||||
# Build Tree Logits.
|
||||
with ops.device(global_step.device):
|
||||
ensemble_handle = model_ops.tree_ensemble_variable(
|
||||
stamp_token=0,
|
||||
tree_ensemble_config="", # Initialize an empty ensemble.
|
||||
name="ensemble_model")
|
||||
|
||||
tree_features = features.copy()
|
||||
if dnn_input_layer_to_tree:
|
||||
tree_features["dnn_input_layer"] = input_layer
|
||||
tree_feature_columns.append(layers.real_valued_column("dnn_input_layer"))
|
||||
gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
|
||||
is_chief=config.is_chief,
|
||||
num_ps_replicas=config.num_ps_replicas,
|
||||
ensemble_handle=ensemble_handle,
|
||||
center_bias=tree_center_bias,
|
||||
examples_per_layer=tree_examples_per_layer,
|
||||
learner_config=tree_learner_config,
|
||||
feature_columns=tree_feature_columns,
|
||||
logits_dimension=head.logits_dimension,
|
||||
features=tree_features,
|
||||
use_core_columns=use_core_versions)
|
||||
|
||||
with ops.name_scope("gbdt"):
|
||||
predictions_dict = gbdt_model.predict(mode)
|
||||
tree_logits = predictions_dict["predictions"]
|
||||
|
||||
def _tree_train_op_fn(loss):
|
||||
"""Returns the op to optimize the loss."""
|
||||
if dnn_to_tree_distillation_param:
|
||||
loss_weight, loss_fn = dnn_to_tree_distillation_param
|
||||
# pylint: disable=protected-access
|
||||
if use_core_versions:
|
||||
weight_tensor = head_lib._weight_tensor(features, head._weight_column)
|
||||
else:
|
||||
weight_tensor = head_lib._weight_tensor(
|
||||
features, head.weight_column_name)
|
||||
# pylint: enable=protected-access
|
||||
dnn_logits_fixed = array_ops.stop_gradient(dnn_logits)
|
||||
|
||||
if loss_fn is None:
|
||||
# we create the loss_fn similar to the head loss_fn for
|
||||
# multi_class_head used previously as the default one.
|
||||
n_classes = 2 if head.logits_dimension == 1 else head.logits_dimension
|
||||
loss_fn = distillation_loss.create_dnn_to_tree_cross_entropy_loss_fn(
|
||||
n_classes)
|
||||
|
||||
dnn_to_tree_distillation_loss = loss_weight * loss_fn(
|
||||
dnn_logits_fixed, tree_logits, weight_tensor)
|
||||
summary.scalar("dnn_to_tree_distillation_loss",
|
||||
dnn_to_tree_distillation_loss)
|
||||
loss += dnn_to_tree_distillation_loss
|
||||
|
||||
update_op = gbdt_model.train(loss, predictions_dict, labels)
|
||||
with ops.control_dependencies(
|
||||
[update_op]), (ops.colocate_with(global_step)):
|
||||
update_op = state_ops.assign_add(global_step, 1).op
|
||||
return update_op
|
||||
|
||||
if predict_with_tree_only:
|
||||
if mode == model_fn.ModeKeys.TRAIN or mode == model_fn.ModeKeys.INFER:
|
||||
tree_train_logits = tree_logits
|
||||
else:
|
||||
tree_train_logits = control_flow_ops.cond(
|
||||
global_step > dnn_steps_to_train,
|
||||
lambda: tree_logits,
|
||||
lambda: dnn_logits)
|
||||
else:
|
||||
tree_train_logits = dnn_logits + tree_logits
|
||||
|
||||
def _no_train_op_fn(loss):
|
||||
"""Returns a no-op."""
|
||||
del loss
|
||||
return control_flow_ops.no_op()
|
||||
|
||||
if tree_center_bias:
|
||||
num_trees += 1
|
||||
finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
|
||||
|
||||
if output_type == model.ModelBuilderOutputType.MODEL_FN_OPS:
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
features=features,
|
||||
mode=mode,
|
||||
labels=labels,
|
||||
train_op_fn=_no_train_op_fn,
|
||||
logits=tree_train_logits)
|
||||
if mode != model_fn.ModeKeys.TRAIN:
|
||||
return model_fn_ops
|
||||
dnn_train_op = head.create_model_fn_ops(
|
||||
features=features,
|
||||
mode=mode,
|
||||
labels=labels,
|
||||
train_op_fn=dnn_train_op_fn,
|
||||
logits=dnn_logits).train_op
|
||||
tree_train_op = head.create_model_fn_ops(
|
||||
features=tree_features,
|
||||
mode=mode,
|
||||
labels=labels,
|
||||
train_op_fn=_tree_train_op_fn,
|
||||
logits=tree_train_logits).train_op
|
||||
|
||||
# Add the hooks
|
||||
model_fn_ops.training_hooks.extend([
|
||||
trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train,
|
||||
tree_train_op),
|
||||
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
|
||||
finalized_trees,
|
||||
override_global_step_value)
|
||||
])
|
||||
return model_fn_ops
|
||||
|
||||
elif output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC:
|
||||
fusion_spec = head.create_estimator_spec(
|
||||
features=features,
|
||||
mode=mode,
|
||||
labels=labels,
|
||||
train_op_fn=_no_train_op_fn,
|
||||
logits=tree_train_logits)
|
||||
if mode != model_fn.ModeKeys.TRAIN:
|
||||
return fusion_spec
|
||||
dnn_spec = head.create_estimator_spec(
|
||||
features=features,
|
||||
mode=mode,
|
||||
labels=labels,
|
||||
train_op_fn=dnn_train_op_fn,
|
||||
logits=dnn_logits)
|
||||
tree_spec = head.create_estimator_spec(
|
||||
features=tree_features,
|
||||
mode=mode,
|
||||
labels=labels,
|
||||
train_op_fn=_tree_train_op_fn,
|
||||
logits=tree_train_logits)
|
||||
|
||||
training_hooks = [
|
||||
trainer_hooks.SwitchTrainOp(dnn_spec.train_op, dnn_steps_to_train,
|
||||
tree_spec.train_op),
|
||||
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
|
||||
finalized_trees,
|
||||
override_global_step_value)
|
||||
]
|
||||
fusion_spec = fusion_spec._replace(training_hooks=training_hooks +
|
||||
list(fusion_spec.training_hooks))
|
||||
return fusion_spec
|
||||
|
||||
|
||||
class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
|
||||
"""A classifier that uses a combined DNN/GBDT model."""
|
||||
|
||||
def __init__(self,
|
||||
dnn_hidden_units,
|
||||
dnn_feature_columns,
|
||||
tree_learner_config,
|
||||
num_trees,
|
||||
tree_examples_per_layer,
|
||||
n_classes=2,
|
||||
weight_column_name=None,
|
||||
model_dir=None,
|
||||
config=None,
|
||||
label_name=None,
|
||||
label_keys=None,
|
||||
feature_engineering_fn=None,
|
||||
dnn_optimizer="Adagrad",
|
||||
dnn_activation_fn=nn.relu,
|
||||
dnn_dropout=None,
|
||||
dnn_input_layer_partitioner=None,
|
||||
dnn_input_layer_to_tree=True,
|
||||
dnn_steps_to_train=10000,
|
||||
predict_with_tree_only=False,
|
||||
tree_feature_columns=None,
|
||||
tree_center_bias=False,
|
||||
dnn_to_tree_distillation_param=None,
|
||||
use_core_versions=False,
|
||||
override_global_step_value=None):
|
||||
"""Initializes a DNNBoostedTreeCombinedClassifier instance.
|
||||
|
||||
Args:
|
||||
dnn_hidden_units: List of hidden units per layer for DNN.
|
||||
dnn_feature_columns: An iterable containing all the feature columns
|
||||
used by the model's DNN.
|
||||
tree_learner_config: A config for the tree learner.
|
||||
num_trees: Number of trees to grow model to after training DNN.
|
||||
tree_examples_per_layer: Number of examples to accumulate before
|
||||
growing the tree a layer. This value has a big impact on model
|
||||
quality and should be set equal to the number of examples in
|
||||
training dataset if possible. It can also be a function that computes
|
||||
the number of examples based on the depth of the layer that's
|
||||
being built.
|
||||
n_classes: The number of label classes.
|
||||
weight_column_name: The name of weight column.
|
||||
model_dir: Directory for model exports.
|
||||
config: `RunConfig` of the estimator.
|
||||
label_name: String, name of the key in label dict. Can be null if label
|
||||
is a tensor (single headed models).
|
||||
label_keys: Optional list of strings with size `[n_classes]` defining the
|
||||
label vocabulary. Only supported for `n_classes` > 2.
|
||||
feature_engineering_fn: Feature engineering function. Takes features and
|
||||
labels which are the output of `input_fn` and returns features and
|
||||
labels which will be fed into the model.
|
||||
dnn_optimizer: string, `Optimizer` object, or callable that defines the
|
||||
optimizer to use for training the DNN. If `None`, will use the Adagrad
|
||||
optimizer with default learning rate.
|
||||
dnn_activation_fn: Activation function applied to each layer of the DNN.
|
||||
If `None`, will use `tf.nn.relu`.
|
||||
dnn_dropout: When not `None`, the probability to drop out a given
|
||||
unit in the DNN.
|
||||
dnn_input_layer_partitioner: Partitioner for input layer of the DNN.
|
||||
Defaults to `min_max_variable_partitioner` with `min_slice_size`
|
||||
64 << 20.
|
||||
dnn_input_layer_to_tree: Whether to provide the DNN's input layer
|
||||
as a feature to the tree.
|
||||
dnn_steps_to_train: Number of steps to train dnn for before switching
|
||||
to gbdt.
|
||||
predict_with_tree_only: Whether to use only the tree model output as the
|
||||
final prediction.
|
||||
tree_feature_columns: An iterable containing all the feature columns
|
||||
used by the model's boosted trees. If dnn_input_layer_to_tree is
|
||||
set to True, these features are in addition to dnn_feature_columns.
|
||||
tree_center_bias: Whether a separate tree should be created for
|
||||
first fitting the bias.
|
||||
dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the
|
||||
float defines the weight of the distillation loss, and the loss_fn, for
|
||||
computing distillation loss, takes dnn_logits, tree_logits and weight
|
||||
tensor. If the entire tuple is None, no distillation will be applied. If
|
||||
only the loss_fn is None, we will take the sigmoid/softmax cross entropy
|
||||
loss by default. When distillation is applied, `predict_with_tree_only`
|
||||
will be set to True.
|
||||
use_core_versions: Whether feature columns and loss are from the core (as
|
||||
opposed to contrib) version of tensorflow.
|
||||
override_global_step_value: If after the training is done, global step
|
||||
value must be reset to this value. This is particularly useful for hyper
|
||||
parameter tuning, which can't recognize early stopping due to the number
|
||||
of trees. If None, no override of global step will happen.
|
||||
"""
|
||||
head = head_lib.multi_class_head(
|
||||
n_classes=n_classes,
|
||||
label_name=label_name,
|
||||
label_keys=label_keys,
|
||||
weight_column_name=weight_column_name,
|
||||
enable_centered_bias=False)
|
||||
|
||||
def _model_fn(features, labels, mode, config):
|
||||
return _dnn_tree_combined_model_fn(
|
||||
features=features,
|
||||
labels=labels,
|
||||
mode=mode,
|
||||
head=head,
|
||||
dnn_hidden_units=dnn_hidden_units,
|
||||
dnn_feature_columns=dnn_feature_columns,
|
||||
tree_learner_config=tree_learner_config,
|
||||
num_trees=num_trees,
|
||||
tree_examples_per_layer=tree_examples_per_layer,
|
||||
config=config,
|
||||
dnn_optimizer=dnn_optimizer,
|
||||
dnn_activation_fn=dnn_activation_fn,
|
||||
dnn_dropout=dnn_dropout,
|
||||
dnn_input_layer_partitioner=dnn_input_layer_partitioner,
|
||||
dnn_input_layer_to_tree=dnn_input_layer_to_tree,
|
||||
dnn_steps_to_train=dnn_steps_to_train,
|
||||
predict_with_tree_only=predict_with_tree_only,
|
||||
tree_feature_columns=tree_feature_columns,
|
||||
tree_center_bias=tree_center_bias,
|
||||
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
|
||||
use_core_versions=use_core_versions,
|
||||
override_global_step_value=override_global_step_value)
|
||||
|
||||
super(DNNBoostedTreeCombinedClassifier, self).__init__(
|
||||
model_fn=_model_fn,
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
feature_engineering_fn=feature_engineering_fn)
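# --- Editorial sketch (not part of the original diff) ---
# A minimal, hypothetical construction of the classifier defined above, shown
# as a helper so nothing runs at import time. The hidden units, feature column
# and tree settings are illustrative placeholders, not recommendations; the
# local imports mirror the contrib modules used elsewhere in this commit.


def _example_build_combined_classifier():
  """Builds a small DNNBoostedTreeCombinedClassifier for illustration only."""
  from tensorflow.contrib.boosted_trees.proto import learner_pb2
  from tensorflow.contrib.layers.python.layers import feature_column

  learner_config = learner_pb2.LearnerConfig()
  learner_config.num_classes = 2
  learner_config.constraints.max_tree_depth = 3
  return DNNBoostedTreeCombinedClassifier(
      dnn_hidden_units=[32, 16],
      dnn_feature_columns=[feature_column.real_valued_column("x")],
      tree_learner_config=learner_config,
      num_trees=10,
      tree_examples_per_layer=1000,
      n_classes=2,
      dnn_steps_to_train=500)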
|
||||
|
||||
|
||||
class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
|
||||
"""A regressor that uses a combined DNN/GBDT model."""
|
||||
|
||||
def __init__(self,
|
||||
dnn_hidden_units,
|
||||
dnn_feature_columns,
|
||||
tree_learner_config,
|
||||
num_trees,
|
||||
tree_examples_per_layer,
|
||||
weight_column_name=None,
|
||||
model_dir=None,
|
||||
config=None,
|
||||
label_name=None,
|
||||
label_dimension=1,
|
||||
feature_engineering_fn=None,
|
||||
dnn_optimizer="Adagrad",
|
||||
dnn_activation_fn=nn.relu,
|
||||
dnn_dropout=None,
|
||||
dnn_input_layer_partitioner=None,
|
||||
dnn_input_layer_to_tree=True,
|
||||
dnn_steps_to_train=10000,
|
||||
predict_with_tree_only=False,
|
||||
tree_feature_columns=None,
|
||||
tree_center_bias=False,
|
||||
dnn_to_tree_distillation_param=None,
|
||||
use_core_versions=False,
|
||||
override_global_step_value=None):
|
||||
"""Initializes a DNNBoostedTreeCombinedRegressor instance.
|
||||
|
||||
Args:
|
||||
dnn_hidden_units: List of hidden units per layer for DNN.
|
||||
dnn_feature_columns: An iterable containing all the feature columns
|
||||
used by the model's DNN.
|
||||
tree_learner_config: A config for the tree learner.
|
||||
num_trees: Number of trees to grow model to after training DNN.
|
||||
tree_examples_per_layer: Number of examples to accumulate before
|
||||
growing the tree a layer. This value has a big impact on model
|
||||
quality and should be set equal to the number of examples in
|
||||
training dataset if possible. It can also be a function that computes
|
||||
the number of examples based on the depth of the layer that's
|
||||
being built.
|
||||
weight_column_name: The name of weight column.
|
||||
model_dir: Directory for model exports.
|
||||
config: `RunConfig` of the estimator.
|
||||
label_name: String, name of the key in label dict. Can be null if label
|
||||
is a tensor (single headed models).
|
||||
label_dimension: Number of regression labels per example. This is the size
|
||||
of the last dimension of the labels `Tensor` (typically, this has shape
|
||||
`[batch_size, label_dimension]`).
|
||||
feature_engineering_fn: Feature engineering function. Takes features and
|
||||
labels which are the output of `input_fn` and returns features and
|
||||
labels which will be fed into the model.
|
||||
dnn_optimizer: string, `Optimizer` object, or callable that defines the
|
||||
optimizer to use for training the DNN. If `None`, will use the Adagrad
|
||||
optimizer with default learning rate.
|
||||
dnn_activation_fn: Activation function applied to each layer of the DNN.
|
||||
If `None`, will use `tf.nn.relu`.
|
||||
dnn_dropout: When not `None`, the probability to drop out a given
|
||||
unit in the DNN.
|
||||
dnn_input_layer_partitioner: Partitioner for input layer of the DNN.
|
||||
Defaults to `min_max_variable_partitioner` with `min_slice_size`
|
||||
64 << 20.
|
||||
dnn_input_layer_to_tree: Whether to provide the DNN's input layer
|
||||
as a feature to the tree.
|
||||
dnn_steps_to_train: Number of steps to train dnn for before switching
|
||||
to gbdt.
|
||||
predict_with_tree_only: Whether to use only the tree model output as the
|
||||
final prediction.
|
||||
tree_feature_columns: An iterable containing all the feature columns
|
||||
used by the model's boosted trees. If dnn_input_layer_to_tree is
|
||||
set to True, these features are in addition to dnn_feature_columns.
|
||||
tree_center_bias: Whether a separate tree should be created for
|
||||
first fitting the bias.
|
||||
dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the
|
||||
float defines the weight of the distillation loss, and the loss_fn, for
|
||||
computing distillation loss, takes dnn_logits, tree_logits and weight
|
||||
tensor. If the entire tuple is None, no distillation will be applied. If
|
||||
only the loss_fn is None, we will take the sigmoid/softmax cross entropy
|
||||
loss by default. When distillation is applied, `predict_with_tree_only`
|
||||
will be set to True.
|
||||
use_core_versions: Whether feature columns and loss are from the core (as
|
||||
opposed to contrib) version of tensorflow.
|
||||
override_global_step_value: If after the training is done, global step
|
||||
value must be reset to this value. This is particularly useful for hyper
|
||||
parameter tuning, which can't recognize early stopping due to the number
|
||||
of trees. If None, no override of global step will happen.
|
||||
"""
|
||||
head = head_lib.regression_head(
|
||||
label_name=label_name,
|
||||
label_dimension=label_dimension,
|
||||
weight_column_name=weight_column_name,
|
||||
enable_centered_bias=False)
|
||||
|
||||
# num_classes needed for GradientBoostedDecisionTreeModel
|
||||
if label_dimension == 1:
|
||||
tree_learner_config.num_classes = 2
|
||||
else:
|
||||
tree_learner_config.num_classes = label_dimension
|
||||
|
||||
def _model_fn(features, labels, mode, config):
|
||||
return _dnn_tree_combined_model_fn(
|
||||
features=features,
|
||||
labels=labels,
|
||||
mode=mode,
|
||||
head=head,
|
||||
dnn_hidden_units=dnn_hidden_units,
|
||||
dnn_feature_columns=dnn_feature_columns,
|
||||
tree_learner_config=tree_learner_config,
|
||||
num_trees=num_trees,
|
||||
tree_examples_per_layer=tree_examples_per_layer,
|
||||
config=config,
|
||||
dnn_optimizer=dnn_optimizer,
|
||||
dnn_activation_fn=dnn_activation_fn,
|
||||
dnn_dropout=dnn_dropout,
|
||||
dnn_input_layer_partitioner=dnn_input_layer_partitioner,
|
||||
dnn_input_layer_to_tree=dnn_input_layer_to_tree,
|
||||
dnn_steps_to_train=dnn_steps_to_train,
|
||||
predict_with_tree_only=predict_with_tree_only,
|
||||
tree_feature_columns=tree_feature_columns,
|
||||
tree_center_bias=tree_center_bias,
|
||||
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
|
||||
use_core_versions=use_core_versions,
|
||||
override_global_step_value=override_global_step_value)
|
||||
|
||||
super(DNNBoostedTreeCombinedRegressor, self).__init__(
|
||||
model_fn=_model_fn,
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
feature_engineering_fn=feature_engineering_fn)
|
||||
|
||||
|
||||
class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
|
||||
"""An estimator that uses a combined DNN/GBDT model.
|
||||
|
||||
Useful for training with user specified `Head`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dnn_hidden_units,
|
||||
dnn_feature_columns,
|
||||
tree_learner_config,
|
||||
num_trees,
|
||||
tree_examples_per_layer,
|
||||
head,
|
||||
model_dir=None,
|
||||
config=None,
|
||||
feature_engineering_fn=None,
|
||||
dnn_optimizer="Adagrad",
|
||||
dnn_activation_fn=nn.relu,
|
||||
dnn_dropout=None,
|
||||
dnn_input_layer_partitioner=None,
|
||||
dnn_input_layer_to_tree=True,
|
||||
dnn_steps_to_train=10000,
|
||||
predict_with_tree_only=False,
|
||||
tree_feature_columns=None,
|
||||
tree_center_bias=False,
|
||||
dnn_to_tree_distillation_param=None,
|
||||
use_core_versions=False,
|
||||
override_global_step_value=None):
|
||||
"""Initializes a DNNBoostedTreeCombinedEstimator instance.
|
||||
|
||||
Args:
|
||||
dnn_hidden_units: List of hidden units per layer for DNN.
|
||||
dnn_feature_columns: An iterable containing all the feature columns
|
||||
used by the model's DNN.
|
||||
tree_learner_config: A config for the tree learner.
|
||||
num_trees: Number of trees to grow model to after training DNN.
|
||||
tree_examples_per_layer: Number of examples to accumulate before
|
||||
growing the tree a layer. This value has a big impact on model
|
||||
quality and should be set equal to the number of examples in
|
||||
training dataset if possible. It can also be a function that computes
|
||||
the number of examples based on the depth of the layer that's
|
||||
being built.
|
||||
head: `Head` instance.
|
||||
model_dir: Directory for model exports.
|
||||
config: `RunConfig` of the estimator.
|
||||
feature_engineering_fn: Feature engineering function. Takes features and
|
||||
labels which are the output of `input_fn` and returns features and
|
||||
labels which will be fed into the model.
|
||||
dnn_optimizer: string, `Optimizer` object, or callable that defines the
|
||||
optimizer to use for training the DNN. If `None`, will use the Adagrad
|
||||
optimizer with default learning rate.
|
||||
dnn_activation_fn: Activation function applied to each layer of the DNN.
|
||||
If `None`, will use `tf.nn.relu`.
|
||||
dnn_dropout: When not `None`, the probability to drop out a given
|
||||
unit in the DNN.
|
||||
dnn_input_layer_partitioner: Partitioner for input layer of the DNN.
|
||||
Defaults to `min_max_variable_partitioner` with `min_slice_size`
|
||||
64 << 20.
|
||||
dnn_input_layer_to_tree: Whether to provide the DNN's input layer
|
||||
as a feature to the tree.
|
||||
dnn_steps_to_train: Number of steps to train dnn for before switching
|
||||
to gbdt.
|
||||
predict_with_tree_only: Whether to use only the tree model output as the
|
||||
final prediction.
|
||||
tree_feature_columns: An iterable containing all the feature columns
|
||||
used by the model's boosted trees. If dnn_input_layer_to_tree is
|
||||
set to True, these features are in addition to dnn_feature_columns.
|
||||
tree_center_bias: Whether a separate tree should be created for
|
||||
first fitting the bias.
|
||||
dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the
|
||||
float defines the weight of the distillation loss, and the loss_fn, for
|
||||
computing distillation loss, takes dnn_logits, tree_logits and weight
|
||||
tensor. If the entire tuple is None, no distillation will be applied. If
|
||||
only the loss_fn is None, we will take the sigmoid/softmax cross entropy
|
||||
loss by default. When distillation is applied, `predict_with_tree_only`
|
||||
will be set to True.
|
||||
use_core_versions: Whether feature columns and loss are from the core (as
|
||||
opposed to contrib) version of tensorflow.
|
||||
override_global_step_value: If after the training is done, global step
|
||||
value must be reset to this value. This is particularly useful for hyper
|
||||
parameter tuning, which can't recognize early stopping due to the number
|
||||
of trees. If None, no override of global step will happen.
|
||||
"""
|
||||
|
||||
def _model_fn(features, labels, mode, config):
|
||||
return _dnn_tree_combined_model_fn(
|
||||
features=features,
|
||||
labels=labels,
|
||||
mode=mode,
|
||||
head=head,
|
||||
dnn_hidden_units=dnn_hidden_units,
|
||||
dnn_feature_columns=dnn_feature_columns,
|
||||
tree_learner_config=tree_learner_config,
|
||||
num_trees=num_trees,
|
||||
tree_examples_per_layer=tree_examples_per_layer,
|
||||
config=config,
|
||||
dnn_optimizer=dnn_optimizer,
|
||||
dnn_activation_fn=dnn_activation_fn,
|
||||
dnn_dropout=dnn_dropout,
|
||||
dnn_input_layer_partitioner=dnn_input_layer_partitioner,
|
||||
dnn_input_layer_to_tree=dnn_input_layer_to_tree,
|
||||
dnn_steps_to_train=dnn_steps_to_train,
|
||||
predict_with_tree_only=predict_with_tree_only,
|
||||
tree_feature_columns=tree_feature_columns,
|
||||
tree_center_bias=tree_center_bias,
|
||||
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
|
||||
use_core_versions=use_core_versions,
|
||||
override_global_step_value=override_global_step_value)
|
||||
|
||||
super(DNNBoostedTreeCombinedEstimator, self).__init__(
|
||||
model_fn=_model_fn,
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
feature_engineering_fn=feature_engineering_fn)
|
||||
|
||||
|
||||
class CoreDNNBoostedTreeCombinedEstimator(core_estimator.Estimator):
|
||||
"""Initializes a core version of DNNBoostedTreeCombinedEstimator.
|
||||
|
||||
Args:
|
||||
dnn_hidden_units: List of hidden units per layer for DNN.
|
||||
dnn_feature_columns: An iterable containing all the feature columns
|
||||
used by the model's DNN.
|
||||
tree_learner_config: A config for the tree learner.
|
||||
num_trees: Number of trees to grow model to after training DNN.
|
||||
tree_examples_per_layer: Number of examples to accumulate before
|
||||
growing the tree a layer. This value has a big impact on model
|
||||
quality and should be set equal to the number of examples in
|
||||
training dataset if possible. It can also be a function that computes
|
||||
the number of examples based on the depth of the layer that's
|
||||
being built.
|
||||
head: `Head` instance.
|
||||
model_dir: Directory for model exports.
|
||||
config: `RunConfig` of the estimator.
|
||||
dnn_optimizer: string, `Optimizer` object, or callable that defines the
|
||||
optimizer to use for training the DNN. If `None`, will use the Adagrad
|
||||
optimizer with default learning rate.
|
||||
dnn_activation_fn: Activation function applied to each layer of the DNN.
|
||||
If `None`, will use `tf.nn.relu`.
|
||||
dnn_dropout: When not `None`, the probability to drop out a given
|
||||
unit in the DNN.
|
||||
dnn_input_layer_partitioner: Partitioner for input layer of the DNN.
|
||||
Defaults to `min_max_variable_partitioner` with `min_slice_size`
|
||||
64 << 20.
|
||||
dnn_input_layer_to_tree: Whether to provide the DNN's input layer
|
||||
as a feature to the tree.
|
||||
dnn_steps_to_train: Number of steps to train dnn for before switching
|
||||
to gbdt.
|
||||
predict_with_tree_only: Whether to use only the tree model output as the
|
||||
final prediction.
|
||||
tree_feature_columns: An iterable containing all the feature columns
|
||||
used by the model's boosted trees. If dnn_input_layer_to_tree is
|
||||
set to True, these features are in addition to dnn_feature_columns.
|
||||
tree_center_bias: Whether a separate tree should be created for
|
||||
first fitting the bias.
|
||||
dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the
|
||||
float defines the weight of the distillation loss, and the loss_fn, for
|
||||
computing distillation loss, takes dnn_logits, tree_logits and weight
|
||||
tensor. If the entire tuple is None, no distillation will be applied. If
|
||||
only the loss_fn is None, we will take the sigmoid/softmax cross entropy
|
||||
loss by default. When distillation is applied, `predict_with_tree_only`
|
||||
will be set to True.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dnn_hidden_units,
|
||||
dnn_feature_columns,
|
||||
tree_learner_config,
|
||||
num_trees,
|
||||
tree_examples_per_layer,
|
||||
head,
|
||||
model_dir=None,
|
||||
config=None,
|
||||
dnn_optimizer="Adagrad",
|
||||
dnn_activation_fn=nn.relu,
|
||||
dnn_dropout=None,
|
||||
dnn_input_layer_partitioner=None,
|
||||
dnn_input_layer_to_tree=True,
|
||||
dnn_steps_to_train=10000,
|
||||
predict_with_tree_only=False,
|
||||
tree_feature_columns=None,
|
||||
tree_center_bias=False,
|
||||
dnn_to_tree_distillation_param=None):
|
||||
|
||||
def _model_fn(features, labels, mode, config):
|
||||
return _dnn_tree_combined_model_fn(
|
||||
features=features,
|
||||
labels=labels,
|
||||
mode=mode,
|
||||
head=head,
|
||||
dnn_hidden_units=dnn_hidden_units,
|
||||
dnn_feature_columns=dnn_feature_columns,
|
||||
tree_learner_config=tree_learner_config,
|
||||
num_trees=num_trees,
|
||||
tree_examples_per_layer=tree_examples_per_layer,
|
||||
config=config,
|
||||
dnn_optimizer=dnn_optimizer,
|
||||
dnn_activation_fn=dnn_activation_fn,
|
||||
dnn_dropout=dnn_dropout,
|
||||
dnn_input_layer_partitioner=dnn_input_layer_partitioner,
|
||||
dnn_input_layer_to_tree=dnn_input_layer_to_tree,
|
||||
dnn_steps_to_train=dnn_steps_to_train,
|
||||
predict_with_tree_only=predict_with_tree_only,
|
||||
tree_feature_columns=tree_feature_columns,
|
||||
tree_center_bias=tree_center_bias,
|
||||
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
|
||||
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC,
|
||||
use_core_versions=True,
|
||||
override_global_step_value=None)
|
||||
|
||||
super(CoreDNNBoostedTreeCombinedEstimator, self).__init__(
|
||||
model_fn=_model_fn, model_dir=model_dir, config=config)
|
@ -1,249 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for combined DNN + GBDT estimators."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from tensorflow.contrib.boosted_trees.estimator_batch import dnn_tree_combined_estimator as estimator
|
||||
from tensorflow.contrib.boosted_trees.proto import learner_pb2
|
||||
from tensorflow.contrib.layers.python.layers import feature_column
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils
|
||||
from tensorflow.contrib.learn.python.learn.estimators import run_config
|
||||
from tensorflow.python.estimator import exporter
|
||||
from tensorflow.python.estimator.canned import head as head_lib
|
||||
from tensorflow.python.estimator.export import export
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.feature_column import feature_column_lib as core_feature_column
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.losses import losses
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.training import checkpoint_utils
|
||||
|
||||
|
||||
def _train_input_fn():
|
||||
features = {
|
||||
"x": constant_op.constant([[2.], [1.], [1.]])
|
||||
}
|
||||
label = constant_op.constant([[1], [0], [0]], dtype=dtypes.int32)
|
||||
return features, label
|
||||
|
||||
|
||||
def _eval_input_fn():
|
||||
features = {
|
||||
"x": constant_op.constant([[1.], [2.], [2.]])
|
||||
}
|
||||
label = constant_op.constant([[0], [1], [1]], dtype=dtypes.int32)
|
||||
return features, label
|
||||
|
||||
|
||||
class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testClassifierContract(self):
|
||||
estimator_test_utils.assert_estimator_contract(
|
||||
self, estimator.DNNBoostedTreeCombinedClassifier)
|
||||
|
||||
def testRegressorContract(self):
|
||||
estimator_test_utils.assert_estimator_contract(
|
||||
self, estimator.DNNBoostedTreeCombinedRegressor)
|
||||
|
||||
def testEstimatorContract(self):
|
||||
estimator_test_utils.assert_estimator_contract(
|
||||
self, estimator.DNNBoostedTreeCombinedEstimator)
|
||||
|
||||
def testNoDNNFeatureColumns(self):
|
||||
learner_config = learner_pb2.LearnerConfig()
|
||||
learner_config.num_classes = 2
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
"dnn_feature_columns must be specified"):
|
||||
classifier = estimator.DNNBoostedTreeCombinedClassifier(
|
||||
dnn_hidden_units=[1],
|
||||
dnn_feature_columns=[],
|
||||
tree_learner_config=learner_config,
|
||||
num_trees=1,
|
||||
tree_examples_per_layer=3,
|
||||
n_classes=2)
|
||||
classifier.fit(input_fn=_train_input_fn, steps=5)
|
||||
|
||||
def testFitAndEvaluateDontThrowException(self):
|
||||
learner_config = learner_pb2.LearnerConfig()
|
||||
learner_config.num_classes = 2
|
||||
learner_config.constraints.max_tree_depth = 1
|
||||
model_dir = tempfile.mkdtemp()
|
||||
config = run_config.RunConfig()
|
||||
|
||||
classifier = estimator.DNNBoostedTreeCombinedClassifier(
|
||||
dnn_hidden_units=[1],
|
||||
dnn_feature_columns=[feature_column.real_valued_column("x")],
|
||||
tree_learner_config=learner_config,
|
||||
num_trees=1,
|
||||
tree_examples_per_layer=3,
|
||||
n_classes=2,
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
dnn_steps_to_train=10,
|
||||
dnn_input_layer_to_tree=False,
|
||||
tree_feature_columns=[feature_column.real_valued_column("x")])
|
||||
|
||||
classifier.fit(input_fn=_train_input_fn, steps=15)
|
||||
classifier.evaluate(input_fn=_eval_input_fn, steps=1)
|
||||
|
||||
def testFitAndEvaluateWithDistillation(self):
|
||||
learner_config = learner_pb2.LearnerConfig()
|
||||
learner_config.num_classes = 2
|
||||
learner_config.constraints.max_tree_depth = 1
|
||||
model_dir = tempfile.mkdtemp()
|
||||
config = run_config.RunConfig()
|
||||
|
||||
classifier = estimator.DNNBoostedTreeCombinedClassifier(
|
||||
dnn_hidden_units=[1],
|
||||
dnn_feature_columns=[feature_column.real_valued_column("x")],
|
||||
tree_learner_config=learner_config,
|
||||
num_trees=1,
|
||||
tree_examples_per_layer=3,
|
||||
n_classes=2,
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
dnn_steps_to_train=10,
|
||||
dnn_input_layer_to_tree=False,
|
||||
tree_feature_columns=[feature_column.real_valued_column("x")],
|
||||
dnn_to_tree_distillation_param=(1, None))
|
||||
|
||||
classifier.fit(input_fn=_train_input_fn, steps=15)
|
||||
classifier.evaluate(input_fn=_eval_input_fn, steps=1)
|
||||
|
||||
|
||||
class CoreDNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def _assert_checkpoint(self, model_dir, global_step):
|
||||
reader = checkpoint_utils.load_checkpoint(model_dir)
|
||||
self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
|
||||
|
||||
def testTrainEvaluateInferDoesNotThrowErrorWithNoDnnInput(self):
|
||||
head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
|
||||
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
|
||||
|
||||
learner_config = learner_pb2.LearnerConfig()
|
||||
learner_config.num_classes = 2
|
||||
learner_config.constraints.max_tree_depth = 3
|
||||
model_dir = tempfile.mkdtemp()
|
||||
config = run_config.RunConfig()
|
||||
|
||||
est = estimator.CoreDNNBoostedTreeCombinedEstimator(
|
||||
head=head_fn,
|
||||
dnn_hidden_units=[1],
|
||||
dnn_feature_columns=[core_feature_column.numeric_column("x")],
|
||||
tree_learner_config=learner_config,
|
||||
num_trees=1,
|
||||
tree_examples_per_layer=3,
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
dnn_steps_to_train=10,
|
||||
dnn_input_layer_to_tree=False,
|
||||
tree_feature_columns=[core_feature_column.numeric_column("x")])
|
||||
|
||||
# Train for a few steps.
|
||||
est.train(input_fn=_train_input_fn, steps=1000)
|
||||
# 10 steps for dnn, 3 for 1 tree of depth 3 + 1 after the tree finished
|
||||
self._assert_checkpoint(est.model_dir, global_step=14)
|
||||
res = est.evaluate(input_fn=_eval_input_fn, steps=1)
|
||||
self.assertLess(0.5, res["auc"])
|
||||
est.predict(input_fn=_eval_input_fn)
|
||||
|
||||
def testTrainEvaluateInferDoesNotThrowErrorWithDnnInput(self):
|
||||
head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
|
||||
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
|
||||
|
||||
learner_config = learner_pb2.LearnerConfig()
|
||||
learner_config.num_classes = 2
|
||||
learner_config.constraints.max_tree_depth = 3
|
||||
model_dir = tempfile.mkdtemp()
|
||||
config = run_config.RunConfig()
|
||||
|
||||
est = estimator.CoreDNNBoostedTreeCombinedEstimator(
|
||||
head=head_fn,
|
||||
dnn_hidden_units=[1],
|
||||
dnn_feature_columns=[core_feature_column.numeric_column("x")],
|
||||
tree_learner_config=learner_config,
|
||||
num_trees=1,
|
||||
tree_examples_per_layer=3,
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
dnn_steps_to_train=10,
|
||||
dnn_input_layer_to_tree=True,
|
||||
tree_feature_columns=[])
|
||||
|
||||
# Train for a few steps.
|
||||
est.train(input_fn=_train_input_fn, steps=1000)
|
||||
res = est.evaluate(input_fn=_eval_input_fn, steps=1)
|
||||
self.assertLess(0.5, res["auc"])
|
||||
est.predict(input_fn=_eval_input_fn)
|
||||
|
||||
def testTrainEvaluateWithDnnForInputAndTreeForPredict(self):
|
||||
head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
|
||||
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
|
||||
|
||||
learner_config = learner_pb2.LearnerConfig()
|
||||
learner_config.num_classes = 2
|
||||
learner_config.constraints.max_tree_depth = 3
|
||||
model_dir = tempfile.mkdtemp()
|
||||
config = run_config.RunConfig()
|
||||
|
||||
est = estimator.CoreDNNBoostedTreeCombinedEstimator(
|
||||
head=head_fn,
|
||||
dnn_hidden_units=[1],
|
||||
dnn_feature_columns=[core_feature_column.numeric_column("x")],
|
||||
tree_learner_config=learner_config,
|
||||
num_trees=1,
|
||||
tree_examples_per_layer=3,
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
dnn_steps_to_train=10,
|
||||
dnn_input_layer_to_tree=True,
|
||||
predict_with_tree_only=True,
|
||||
dnn_to_tree_distillation_param=(0.5, None),
|
||||
tree_feature_columns=[])
|
||||
|
||||
# Train for a few steps.
|
||||
est.train(input_fn=_train_input_fn, steps=1000)
|
||||
res = est.evaluate(input_fn=_eval_input_fn, steps=1)
|
||||
self.assertLess(0.5, res["auc"])
|
||||
est.predict(input_fn=_eval_input_fn)
|
||||
serving_input_fn = (
|
||||
export.build_parsing_serving_input_receiver_fn(
|
||||
feature_spec={"x": parsing_ops.FixedLenFeature(
|
||||
[1], dtype=dtypes.float32)}))
|
||||
base_exporter = exporter.FinalExporter(
|
||||
name="Servo",
|
||||
serving_input_receiver_fn=serving_input_fn,
|
||||
assets_extra=None)
|
||||
export_path = os.path.join(model_dir, "export")
|
||||
base_exporter.export(
|
||||
est,
|
||||
export_path=export_path,
|
||||
checkpoint_path=None,
|
||||
eval_result={},
|
||||
is_the_final_export=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
@ -1,837 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""GTFlow Estimator definition."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from tensorflow.contrib.boosted_trees.estimator_batch import model
|
||||
from tensorflow.contrib.boosted_trees.python.utils import losses
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
|
||||
from tensorflow.python.estimator.canned import head as core_head_lib
|
||||
from tensorflow.python.estimator import estimator as core_estimator
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.losses import losses as core_losses
|
||||
from tensorflow.contrib.boosted_trees.estimator_batch import custom_loss_head
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
||||
# ================== Old estimator interface ==================================
# The estimators below were designed for old feature columns and the old
# estimator interface. They can be used with new feature columns and losses by
# setting use_core_libs = True.


class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
|
||||
"""An estimator using gradient boosted decision trees."""
|
||||
|
||||
def __init__(self,
|
||||
learner_config,
|
||||
examples_per_layer,
|
||||
n_classes=2,
|
||||
num_trees=None,
|
||||
feature_columns=None,
|
||||
weight_column_name=None,
|
||||
model_dir=None,
|
||||
config=None,
|
||||
label_keys=None,
|
||||
feature_engineering_fn=None,
|
||||
logits_modifier_function=None,
|
||||
center_bias=True,
|
||||
use_core_libs=False,
|
||||
output_leaf_index=False,
|
||||
override_global_step_value=None,
|
||||
num_quantiles=100):
|
||||
"""Initializes a GradientBoostedDecisionTreeClassifier estimator instance.
|
||||
|
||||
Args:
|
||||
learner_config: A config for the learner.
|
||||
examples_per_layer: Number of examples to accumulate before growing a
|
||||
layer. It can also be a function that computes the number of examples
|
||||
based on the depth of the layer that's being built.
|
||||
n_classes: Number of classes in the classification.
|
||||
num_trees: An int, number of trees to build.
|
||||
feature_columns: A list of feature columns.
|
||||
weight_column_name: Name of the column for weights, or None if not
|
||||
weighted.
|
||||
model_dir: Directory for model exports, etc.
|
||||
config: `RunConfig` object to configure the runtime settings.
|
||||
label_keys: Optional list of strings with size `[n_classes]` defining the
|
||||
label vocabulary. Only supported for `n_classes` > 2.
|
||||
feature_engineering_fn: Feature engineering function. Takes features and
|
||||
labels which are the output of `input_fn` and returns features and
|
||||
labels which will be fed into the model.
|
||||
logits_modifier_function: A modifier function for the logits.
|
||||
center_bias: Whether a separate tree should be created for first fitting
|
||||
the bias.
|
||||
use_core_libs: Whether feature columns and loss are from the core (as
|
||||
opposed to contrib) version of tensorflow.
|
||||
      output_leaf_index: whether to output leaf indices along with predictions
        during inference. The leaf node indexes are available in the
        predictions dict under the key 'leaf_index'. It is a Tensor of rank 2
        with shape [batch_size, num_trees]. For example:
          result_iter = classifier.predict(...)
          for result_dict in result_iter:
            # Access the leaf index list via result_dict["leaf_index"]; it
            # contains one leaf index per tree.
override_global_step_value: If after the training is done, global step
|
||||
value must be reset to this value. This should be used to reset global
|
||||
step to a number > number of steps used to train the current ensemble.
|
||||
For example, the usual way is to train a number of trees and set a very
|
||||
large number of training steps. When the training is done (number of
|
||||
trees were trained), this parameter can be used to set the global step
|
||||
to a large value, making it look like that number of training steps ran.
|
||||
If None, no override of global step will happen.
|
||||
num_quantiles: Number of quantiles to build for numeric feature values.
|
||||
|
||||
Raises:
|
||||
ValueError: If learner_config is not valid.
|
||||
"""
|
||||
if n_classes > 2:
|
||||
# For multi-class classification, use our loss implementation that
|
||||
# supports second order derivative.
|
||||
def loss_fn(labels, logits, weights=None):
|
||||
result = losses.per_example_maxent_loss(
|
||||
labels=labels,
|
||||
logits=logits,
|
||||
weights=weights,
|
||||
num_classes=n_classes)
|
||||
return math_ops.reduce_mean(result[0])
|
||||
else:
|
||||
loss_fn = None
|
||||
head = head_lib.multi_class_head(
|
||||
n_classes=n_classes,
|
||||
weight_column_name=weight_column_name,
|
||||
enable_centered_bias=False,
|
||||
loss_fn=loss_fn,
|
||||
label_keys=label_keys)
|
||||
if learner_config.num_classes == 0:
|
||||
learner_config.num_classes = n_classes
|
||||
elif learner_config.num_classes != n_classes:
|
||||
raise ValueError("n_classes (%d) doesn't match learner_config (%d)." %
|
||||
(n_classes, learner_config.num_classes))
|
||||
super(GradientBoostedDecisionTreeClassifier, self).__init__(
|
||||
model_fn=model.model_builder,
|
||||
params={
|
||||
'head': head,
|
||||
'feature_columns': feature_columns,
|
||||
'learner_config': learner_config,
|
||||
'num_trees': num_trees,
|
||||
'weight_column_name': weight_column_name,
|
||||
'examples_per_layer': examples_per_layer,
|
||||
'center_bias': center_bias,
|
||||
'logits_modifier_function': logits_modifier_function,
|
||||
'use_core_libs': use_core_libs,
|
||||
'output_leaf_index': output_leaf_index,
|
||||
'override_global_step_value': override_global_step_value,
|
||||
'num_quantiles': num_quantiles,
|
||||
},
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
feature_engineering_fn=feature_engineering_fn)
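
A minimal usage sketch for this classifier; the input_fn names, the feature "x", and the hyperparameter values are illustrative assumptions (mirroring the tests earlier in this change), not recommendations:

from tensorflow.contrib.boosted_trees.estimator_batch import estimator
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.layers.python.layers import feature_column

learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2
learner_config.constraints.max_tree_depth = 3

classifier = estimator.GradientBoostedDecisionTreeClassifier(
    learner_config=learner_config,
    examples_per_layer=16,
    num_trees=10,
    feature_columns=[feature_column.real_valued_column("x")])
classifier.fit(input_fn=train_input_fn, steps=100)
classifier.evaluate(input_fn=eval_input_fn, steps=1)
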
|
||||
|
||||
|
||||
class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
|
||||
"""An estimator using gradient boosted decision trees."""
|
||||
|
||||
def __init__(self,
|
||||
learner_config,
|
||||
examples_per_layer,
|
||||
label_dimension=1,
|
||||
num_trees=None,
|
||||
feature_columns=None,
|
||||
label_name=None,
|
||||
weight_column_name=None,
|
||||
model_dir=None,
|
||||
config=None,
|
||||
feature_engineering_fn=None,
|
||||
logits_modifier_function=None,
|
||||
center_bias=True,
|
||||
use_core_libs=False,
|
||||
output_leaf_index=False,
|
||||
override_global_step_value=None,
|
||||
num_quantiles=100):
|
||||
"""Initializes a GradientBoostedDecisionTreeRegressor estimator instance.
|
||||
|
||||
Args:
|
||||
learner_config: A config for the learner.
|
||||
examples_per_layer: Number of examples to accumulate before growing a
|
||||
layer. It can also be a function that computes the number of examples
|
||||
based on the depth of the layer that's being built.
|
||||
label_dimension: Number of regression labels per example. This is the size
|
||||
of the last dimension of the labels `Tensor` (typically, this has shape
|
||||
`[batch_size, label_dimension]`).
|
||||
num_trees: An int, number of trees to build.
|
||||
feature_columns: A list of feature columns.
|
||||
label_name: String, name of the key in label dict. Can be null if label is
|
||||
a tensor (single headed models).
|
||||
weight_column_name: Name of the column for weights, or None if not
|
||||
weighted.
|
||||
model_dir: Directory for model exports, etc.
|
||||
config: `RunConfig` object to configure the runtime settings.
|
||||
feature_engineering_fn: Feature engineering function. Takes features and
|
||||
labels which are the output of `input_fn` and returns features and
|
||||
labels which will be fed into the model.
|
||||
logits_modifier_function: A modifier function for the logits.
|
||||
center_bias: Whether a separate tree should be created for first fitting
|
||||
the bias.
|
||||
use_core_libs: Whether feature columns and loss are from the core (as
|
||||
opposed to contrib) version of tensorflow.
|
||||
output_leaf_index: whether to output leaf indices along with predictions
|
||||
during inference. The leaf node indexes are available in predictions
|
||||
dict by the key 'leaf_index'. For example, result_dict =
|
||||
classifier.predict(...)
|
||||
for example_prediction_result in result_dict: # access leaf index list
|
||||
by example_prediction_result["leaf_index"] # which contains one leaf
|
||||
index per tree
|
||||
override_global_step_value: If after the training is done, global step
|
||||
value must be reset to this value. This should be used to reset global
|
||||
step to a number > number of steps used to train the current ensemble.
|
||||
For example, the usual way is to train a number of trees and set a very
|
||||
large number of training steps. When the training is done (number of
|
||||
trees were trained), this parameter can be used to set the global step
|
||||
to a large value, making it look like that number of training steps ran.
|
||||
If None, no override of global step will happen.
|
||||
num_quantiles: Number of quantiles to build for numeric feature values.
|
||||
"""
|
||||
head = head_lib.regression_head(
|
||||
label_name=label_name,
|
||||
label_dimension=label_dimension,
|
||||
weight_column_name=weight_column_name,
|
||||
enable_centered_bias=False)
|
||||
if label_dimension == 1:
|
||||
learner_config.num_classes = 2
|
||||
else:
|
||||
learner_config.num_classes = label_dimension
|
||||
super(GradientBoostedDecisionTreeRegressor, self).__init__(
|
||||
model_fn=model.model_builder,
|
||||
params={
|
||||
'head': head,
|
||||
'feature_columns': feature_columns,
|
||||
'learner_config': learner_config,
|
||||
'num_trees': num_trees,
|
||||
'weight_column_name': weight_column_name,
|
||||
'examples_per_layer': examples_per_layer,
|
||||
'logits_modifier_function': logits_modifier_function,
|
||||
'center_bias': center_bias,
|
||||
'use_core_libs': use_core_libs,
|
||||
'output_leaf_index': False,
|
||||
'override_global_step_value': override_global_step_value,
|
||||
'num_quantiles': num_quantiles,
|
||||
},
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
feature_engineering_fn=feature_engineering_fn)
|
||||
|
||||
|
||||
class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
|
||||
"""An estimator using gradient boosted decision trees.
|
||||
|
||||
Useful for training with user specified `Head`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
learner_config,
|
||||
examples_per_layer,
|
||||
head,
|
||||
num_trees=None,
|
||||
feature_columns=None,
|
||||
weight_column_name=None,
|
||||
model_dir=None,
|
||||
config=None,
|
||||
feature_engineering_fn=None,
|
||||
logits_modifier_function=None,
|
||||
center_bias=True,
|
||||
use_core_libs=False,
|
||||
output_leaf_index=False,
|
||||
override_global_step_value=None,
|
||||
num_quantiles=100):
|
||||
"""Initializes a GradientBoostedDecisionTreeEstimator estimator instance.
|
||||
|
||||
Args:
|
||||
learner_config: A config for the learner.
|
||||
examples_per_layer: Number of examples to accumulate before growing a
|
||||
layer. It can also be a function that computes the number of examples
|
||||
based on the depth of the layer that's being built.
|
||||
head: `Head` instance.
|
||||
num_trees: An int, number of trees to build.
|
||||
feature_columns: A list of feature columns.
|
||||
weight_column_name: Name of the column for weights, or None if not
|
||||
weighted.
|
||||
model_dir: Directory for model exports, etc.
|
||||
config: `RunConfig` object to configure the runtime settings.
|
||||
feature_engineering_fn: Feature engineering function. Takes features and
|
||||
labels which are the output of `input_fn` and returns features and
|
||||
labels which will be fed into the model.
|
||||
logits_modifier_function: A modifier function for the logits.
|
||||
center_bias: Whether a separate tree should be created for first fitting
|
||||
the bias.
|
||||
use_core_libs: Whether feature columns and loss are from the core (as
|
||||
opposed to contrib) version of tensorflow.
|
||||
output_leaf_index: whether to output leaf indices along with predictions
|
||||
during inference. The leaf node indexes are available in predictions
|
||||
dict by the key 'leaf_index'. For example, result_dict =
|
||||
classifier.predict(...)
|
||||
for example_prediction_result in result_dict: # access leaf index list
|
||||
by example_prediction_result["leaf_index"] # which contains one leaf
|
||||
index per tree
|
||||
override_global_step_value: If after the training is done, global step
|
||||
value must be reset to this value. This should be used to reset global
|
||||
step to a number > number of steps used to train the current ensemble.
|
||||
For example, the usual way is to train a number of trees and set a very
|
||||
large number of training steps. When the training is done (number of
|
||||
trees were trained), this parameter can be used to set the global step
|
||||
to a large value, making it look like that number of training steps ran.
|
||||
If None, no override of global step will happen.
|
||||
num_quantiles: Number of quantiles to build for numeric feature values.
|
||||
"""
|
||||
super(GradientBoostedDecisionTreeEstimator, self).__init__(
|
||||
model_fn=model.model_builder,
|
||||
params={
|
||||
'head': head,
|
||||
'feature_columns': feature_columns,
|
||||
'learner_config': learner_config,
|
||||
'num_trees': num_trees,
|
||||
'weight_column_name': weight_column_name,
|
||||
'examples_per_layer': examples_per_layer,
|
||||
'logits_modifier_function': logits_modifier_function,
|
||||
'center_bias': center_bias,
|
||||
'use_core_libs': use_core_libs,
|
||||
'output_leaf_index': False,
|
||||
'override_global_step_value': override_global_step_value,
|
||||
'num_quantiles': num_quantiles,
|
||||
},
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
feature_engineering_fn=feature_engineering_fn)
|
||||
|
||||
|
||||
class GradientBoostedDecisionTreeRanker(estimator.Estimator):
|
||||
"""A ranking estimator using gradient boosted decision trees."""
|
||||
|
||||
def __init__(self,
|
||||
learner_config,
|
||||
examples_per_layer,
|
||||
head,
|
||||
ranking_model_pair_keys,
|
||||
num_trees=None,
|
||||
feature_columns=None,
|
||||
weight_column_name=None,
|
||||
model_dir=None,
|
||||
config=None,
|
||||
label_keys=None,
|
||||
feature_engineering_fn=None,
|
||||
logits_modifier_function=None,
|
||||
center_bias=False,
|
||||
use_core_libs=False,
|
||||
output_leaf_index=False,
|
||||
override_global_step_value=None,
|
||||
num_quantiles=100):
|
||||
"""Initializes a GradientBoostedDecisionTreeRanker instance.
|
||||
|
||||
    This is an estimator that can be trained on pairwise data and can be
    used for inference on non-paired data. This is essentially LambdaMART.
Args:
|
||||
learner_config: A config for the learner.
|
||||
examples_per_layer: Number of examples to accumulate before growing a
|
||||
layer. It can also be a function that computes the number of examples
|
||||
based on the depth of the layer that's being built.
|
||||
head: `Head` instance.
|
||||
ranking_model_pair_keys: Keys to distinguish between features for left and
|
||||
right part of the training pairs for ranking. For example, for an
|
||||
Example with features "a.f1" and "b.f1", the keys would be ("a", "b").
|
||||
num_trees: An int, number of trees to build.
|
||||
feature_columns: A list of feature columns.
|
||||
weight_column_name: Name of the column for weights, or None if not
|
||||
weighted.
|
||||
model_dir: Directory for model exports, etc.
|
||||
config: `RunConfig` object to configure the runtime settings.
|
||||
label_keys: Optional list of strings with size `[n_classes]` defining the
|
||||
label vocabulary. Only supported for `n_classes` > 2.
|
||||
feature_engineering_fn: Feature engineering function. Takes features and
|
||||
labels which are the output of `input_fn` and returns features and
|
||||
labels which will be fed into the model.
|
||||
logits_modifier_function: A modifier function for the logits.
|
||||
center_bias: Whether a separate tree should be created for first fitting
|
||||
the bias.
|
||||
use_core_libs: Whether feature columns and loss are from the core (as
|
||||
opposed to contrib) version of tensorflow.
|
||||
output_leaf_index: whether to output leaf indices along with predictions
|
||||
during inference. The leaf node indexes are available in predictions
|
||||
dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is
|
||||
[batch_size, num_trees]. For example, result_iter =
|
||||
classifier.predict(...)
|
||||
for result_dict in result_iter: # access leaf index list by
|
||||
result_dict["leaf_index"] # which contains one leaf index per tree
|
||||
override_global_step_value: If after the training is done, global step
|
||||
value must be reset to this value. This should be used to reset global
|
||||
step to a number > number of steps used to train the current ensemble.
|
||||
For example, the usual way is to train a number of trees and set a very
|
||||
large number of training steps. When the training is done (number of
|
||||
trees were trained), this parameter can be used to set the global step
|
||||
to a large value, making it look like that number of training steps ran.
|
||||
If None, no override of global step will happen.
|
||||
num_quantiles: Number of quantiles to build for numeric feature values.
|
||||
|
||||
Raises:
|
||||
ValueError: If learner_config is not valid.
|
||||
"""
|
||||
super(GradientBoostedDecisionTreeRanker, self).__init__(
|
||||
model_fn=model.ranking_model_builder,
|
||||
params={
|
||||
'head': head,
|
||||
'n_classes': 2,
|
||||
'feature_columns': feature_columns,
|
||||
'learner_config': learner_config,
|
||||
'num_trees': num_trees,
|
||||
'weight_column_name': weight_column_name,
|
||||
'examples_per_layer': examples_per_layer,
|
||||
'center_bias': center_bias,
|
||||
'logits_modifier_function': logits_modifier_function,
|
||||
'use_core_libs': use_core_libs,
|
||||
'output_leaf_index': output_leaf_index,
|
||||
'ranking_model_pair_keys': ranking_model_pair_keys,
|
||||
'override_global_step_value': override_global_step_value,
|
||||
'num_quantiles': num_quantiles,
|
||||
},
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
feature_engineering_fn=feature_engineering_fn)
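
A sketch of how paired features are expected to be named when
ranking_model_pair_keys=("a", "b"); the feature name "f1", the values, and the
label convention are illustrative assumptions:

import tensorflow as tf

def ranking_train_input_fn():
  # Every feature appears once per side of the pair: "a.f1" and "b.f1".
  features = {
      "a.f1": tf.constant([[1.0], [2.0]]),
      "b.f1": tf.constant([[0.5], [3.0]]),
  }
  # A label of 1 is assumed to mean the "a" example ranks above the "b" one.
  labels = tf.constant([[1], [0]])
  return features, labels
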
|
||||
|
||||
|
||||
# When using this estimator, make sure to regularize the hessian (at least l2,
|
||||
# min_node_weight)!
|
||||
# TODO(nponomareva): extend to take multiple quantiles in one go.
|
||||
class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator):
|
||||
"""An estimator that does quantile regression and returns quantile estimates."""
|
||||
|
||||
def __init__(self,
|
||||
learner_config,
|
||||
examples_per_layer,
|
||||
quantiles,
|
||||
label_dimension=1,
|
||||
num_trees=None,
|
||||
feature_columns=None,
|
||||
weight_column_name=None,
|
||||
model_dir=None,
|
||||
config=None,
|
||||
feature_engineering_fn=None,
|
||||
logits_modifier_function=None,
|
||||
center_bias=True,
|
||||
use_core_libs=False,
|
||||
output_leaf_index=False,
|
||||
override_global_step_value=None,
|
||||
num_quantiles=100):
|
||||
"""Initializes a GradientBoostedDecisionTreeQuantileRegressor instance.
|
||||
|
||||
Args:
|
||||
learner_config: A config for the learner.
|
||||
examples_per_layer: Number of examples to accumulate before growing a
|
||||
layer. It can also be a function that computes the number of examples
|
||||
based on the depth of the layer that's being built.
|
||||
quantiles: a list of quantiles for the loss, each between 0 and 1.
|
||||
label_dimension: Dimension of regression label. This is the size of the
|
||||
last dimension of the labels `Tensor` (typically, this has shape
|
||||
`[batch_size, label_dimension]`). When label_dimension>1, it is
|
||||
recommended to use multiclass strategy diagonal hessian or full hessian.
|
||||
num_trees: An int, number of trees to build.
|
||||
feature_columns: A list of feature columns.
|
||||
weight_column_name: Name of the column for weights, or None if not
|
||||
weighted.
|
||||
model_dir: Directory for model exports, etc.
|
||||
config: `RunConfig` object to configure the runtime settings.
|
||||
feature_engineering_fn: Feature engineering function. Takes features and
|
||||
labels which are the output of `input_fn` and returns features and
|
||||
labels which will be fed into the model.
|
||||
logits_modifier_function: A modifier function for the logits.
|
||||
center_bias: Whether a separate tree should be created for first fitting
|
||||
the bias.
|
||||
use_core_libs: Whether feature columns and loss are from the core (as
|
||||
opposed to contrib) version of tensorflow.
|
||||
output_leaf_index: whether to output leaf indices along with predictions
|
||||
during inference. The leaf node indexes are available in predictions
|
||||
dict by the key 'leaf_index'. For example, result_dict =
|
||||
classifier.predict(...)
|
||||
for example_prediction_result in result_dict: # access leaf index list
|
||||
by example_prediction_result["leaf_index"] # which contains one leaf
|
||||
index per tree
|
||||
override_global_step_value: If after the training is done, global step
|
||||
value must be reset to this value. This should be used to reset global
|
||||
step to a number > number of steps used to train the current ensemble.
|
||||
For example, the usual way is to train a number of trees and set a very
|
||||
large number of training steps. When the training is done (number of
|
||||
trees were trained), this parameter can be used to set the global step
|
||||
to a large value, making it look like that number of training steps ran.
|
||||
If None, no override of global step will happen.
|
||||
num_quantiles: Number of quantiles to build for numeric feature values.
|
||||
"""
|
||||
|
||||
if len(quantiles) > 1:
|
||||
raise ValueError('For now, just one quantile per estimator is supported')
|
||||
|
||||
def _quantile_regression_head(quantile):
|
||||
# Use quantile regression.
|
||||
head = custom_loss_head.CustomLossHead(
|
||||
loss_fn=functools.partial(
|
||||
losses.per_example_quantile_regression_loss, quantile=quantile),
|
||||
link_fn=array_ops.identity,
|
||||
logit_dimension=label_dimension)
|
||||
return head
|
||||
|
||||
learner_config.num_classes = max(2, label_dimension)
|
||||
|
||||
super(GradientBoostedDecisionTreeQuantileRegressor, self).__init__(
|
||||
model_fn=model.model_builder,
|
||||
params={
|
||||
'head': _quantile_regression_head(quantiles[0]),
|
||||
'feature_columns': feature_columns,
|
||||
'learner_config': learner_config,
|
||||
'num_trees': num_trees,
|
||||
'weight_column_name': weight_column_name,
|
||||
'examples_per_layer': examples_per_layer,
|
||||
'logits_modifier_function': logits_modifier_function,
|
||||
'center_bias': center_bias,
|
||||
'use_core_libs': use_core_libs,
|
||||
'output_leaf_index': False,
|
||||
'override_global_step_value': override_global_step_value,
|
||||
'num_quantiles': num_quantiles,
|
||||
},
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
feature_engineering_fn=feature_engineering_fn)
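
A construction sketch that follows the hessian-regularization note above, with
the same imports as the classifier sketch earlier; the regularization values,
the 0.95 quantile, and the feature setup are illustrative assumptions, and the
field names assume the learner proto's constraints.min_node_weight and
regularization.l2:

learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2
learner_config.constraints.max_tree_depth = 4
learner_config.constraints.min_node_weight = 0.1  # hessian regularization
learner_config.regularization.l2 = 1.0 / 128      # hessian regularization

quantile_regressor = estimator.GradientBoostedDecisionTreeQuantileRegressor(
    learner_config=learner_config,
    examples_per_layer=128,
    quantiles=[0.95],  # only a single quantile per estimator is supported
    num_trees=10,
    feature_columns=[feature_column.real_valued_column("x")])
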
|
||||
|
||||
|
||||
# ================== New Estimator interface ==================================
# The estimators below use the new core Estimator interface and must be used
# with new feature columns and heads.


# For multiclass classification, use the following head since it uses a loss
# that is twice differentiable.
def core_multiclass_head(
|
||||
n_classes,
|
||||
weight_column=None,
|
||||
loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS):
|
||||
"""Core head for multiclass problems."""
|
||||
|
||||
def loss_fn(labels, logits):
|
||||
result = losses.per_example_maxent_loss(
|
||||
# Don't pass the weights: head already multiplies by them.
|
||||
labels=labels, logits=logits, weights=None, num_classes=n_classes)
|
||||
return result[0]
|
||||
|
||||
# pylint:disable=protected-access
|
||||
head_fn = core_head_lib._multi_class_head_with_softmax_cross_entropy_loss(
|
||||
n_classes=n_classes,
|
||||
loss_fn=loss_fn,
|
||||
loss_reduction=loss_reduction,
|
||||
weight_column=weight_column)
|
||||
# pylint:enable=protected-access
|
||||
|
||||
return head_fn
|
||||
|
||||
|
||||
# For quantile regression, use this head with
# CoreGradientBoostedDecisionTreeEstimator, or use
# CoreGradientBoostedDecisionTreeQuantileRegressor directly.
def core_quantile_regression_head(
|
||||
quantiles,
|
||||
label_dimension=1,
|
||||
weight_column=None,
|
||||
loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS):
|
||||
"""Core head for quantile regression problems."""
|
||||
|
||||
def loss_fn(labels, logits):
|
||||
result = losses.per_example_quantile_regression_loss(
|
||||
labels=labels,
|
||||
predictions=logits,
|
||||
# Don't pass the weights: head already multiplies by them.
|
||||
weights=None,
|
||||
quantile=quantiles)
|
||||
return result[0]
|
||||
|
||||
# pylint:disable=protected-access
|
||||
head_fn = core_head_lib._regression_head(
|
||||
label_dimension=label_dimension,
|
||||
loss_fn=loss_fn,
|
||||
loss_reduction=loss_reduction,
|
||||
weight_column=weight_column)
|
||||
# pylint:enable=protected-access
|
||||
return head_fn
|
||||
|
||||
|
||||
class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
|
||||
"""An estimator using gradient boosted decision trees.
|
||||
|
||||
Useful for training with user specified `Head`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
learner_config,
|
||||
examples_per_layer,
|
||||
head,
|
||||
num_trees=None,
|
||||
feature_columns=None,
|
||||
weight_column_name=None,
|
||||
model_dir=None,
|
||||
config=None,
|
||||
label_keys=None,
|
||||
feature_engineering_fn=None,
|
||||
logits_modifier_function=None,
|
||||
center_bias=True,
|
||||
output_leaf_index=False,
|
||||
num_quantiles=100):
|
||||
"""Initializes a core version of GradientBoostedDecisionTreeEstimator.
|
||||
|
||||
Args:
|
||||
learner_config: A config for the learner.
|
||||
examples_per_layer: Number of examples to accumulate before growing a
|
||||
layer. It can also be a function that computes the number of examples
|
||||
based on the depth of the layer that's being built.
|
||||
head: `Head` instance.
|
||||
num_trees: An int, number of trees to build.
|
||||
feature_columns: A list of feature columns.
|
||||
weight_column_name: Name of the column for weights, or None if not
|
||||
weighted.
|
||||
model_dir: Directory for model exports, etc.
|
||||
config: `RunConfig` object to configure the runtime settings.
|
||||
label_keys: Optional list of strings with size `[n_classes]` defining the
|
||||
label vocabulary. Only supported for `n_classes` > 2.
|
||||
feature_engineering_fn: Feature engineering function. Takes features and
|
||||
labels which are the output of `input_fn` and returns features and
|
||||
labels which will be fed into the model.
|
||||
logits_modifier_function: A modifier function for the logits.
|
||||
center_bias: Whether a separate tree should be created for first fitting
|
||||
the bias.
|
||||
output_leaf_index: whether to output leaf indices along with predictions
|
||||
during inference. The leaf node indexes are available in predictions
|
||||
dict by the key 'leaf_index'. For example, result_dict =
|
||||
classifier.predict(...)
|
||||
for example_prediction_result in result_dict: # access leaf index list
|
||||
by example_prediction_result["leaf_index"] # which contains one leaf
|
||||
index per tree
|
||||
num_quantiles: Number of quantiles to build for numeric feature values.
|
||||
"""
|
||||
|
||||
def _model_fn(features, labels, mode, config):
|
||||
return model.model_builder(
|
||||
features=features,
|
||||
labels=labels,
|
||||
mode=mode,
|
||||
config=config,
|
||||
params={
|
||||
'head': head,
|
||||
'feature_columns': feature_columns,
|
||||
'learner_config': learner_config,
|
||||
'num_trees': num_trees,
|
||||
'weight_column_name': weight_column_name,
|
||||
'examples_per_layer': examples_per_layer,
|
||||
'center_bias': center_bias,
|
||||
'logits_modifier_function': logits_modifier_function,
|
||||
'use_core_libs': True,
|
||||
'output_leaf_index': output_leaf_index,
|
||||
'override_global_step_value': None,
|
||||
'num_quantiles': num_quantiles,
|
||||
},
|
||||
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
|
||||
|
||||
super(CoreGradientBoostedDecisionTreeEstimator, self).__init__(
|
||||
model_fn=_model_fn, model_dir=model_dir, config=config)
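
A sketch pairing this estimator with core_multiclass_head and a core numeric
column; the input_fn names, feature name, and values are illustrative
assumptions, and the feature-column import alias mirrors the tests in this
change:

from tensorflow.python.feature_column import feature_column_lib as core_feature_column

learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 3
learner_config.constraints.max_tree_depth = 3

est = CoreGradientBoostedDecisionTreeEstimator(
    head=core_multiclass_head(n_classes=3),
    learner_config=learner_config,
    examples_per_layer=16,
    num_trees=10,
    feature_columns=[core_feature_column.numeric_column("x")])
est.train(input_fn=train_input_fn, steps=100)
est.evaluate(input_fn=eval_input_fn, steps=1)
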
|
||||
|
||||
|
||||
class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
|
||||
"""A ranking estimator using gradient boosted decision trees."""
|
||||
|
||||
def __init__(self,
|
||||
learner_config,
|
||||
examples_per_layer,
|
||||
head,
|
||||
ranking_model_pair_keys,
|
||||
num_trees=None,
|
||||
feature_columns=None,
|
||||
weight_column_name=None,
|
||||
model_dir=None,
|
||||
config=None,
|
||||
label_keys=None,
|
||||
logits_modifier_function=None,
|
||||
center_bias=False,
|
||||
output_leaf_index=False,
|
||||
num_quantiles=100):
|
||||
"""Initializes a GradientBoostedDecisionTreeRanker instance.
|
||||
|
||||
    This is an estimator that can be trained on pairwise data and can be
    used for inference on non-paired data. This is essentially LambdaMART.
Args:
|
||||
learner_config: A config for the learner.
|
||||
examples_per_layer: Number of examples to accumulate before growing a
|
||||
layer. It can also be a function that computes the number of examples
|
||||
based on the depth of the layer that's being built.
|
||||
head: `Head` instance.
|
||||
ranking_model_pair_keys: Keys to distinguish between features for left and
|
||||
right part of the training pairs for ranking. For example, for an
|
||||
Example with features "a.f1" and "b.f1", the keys would be ("a", "b").
|
||||
num_trees: An int, number of trees to build.
|
||||
feature_columns: A list of feature columns.
|
||||
weight_column_name: Name of the column for weights, or None if not
|
||||
weighted.
|
||||
model_dir: Directory for model exports, etc.
|
||||
config: `RunConfig` object to configure the runtime settings.
|
||||
label_keys: Optional list of strings with size `[n_classes]` defining the
|
||||
label vocabulary. Only supported for `n_classes` > 2.
|
||||
logits_modifier_function: A modifier function for the logits.
|
||||
center_bias: Whether a separate tree should be created for first fitting
|
||||
the bias.
|
||||
output_leaf_index: whether to output leaf indices along with predictions
|
||||
during inference. The leaf node indexes are available in predictions
|
||||
dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is
|
||||
[batch_size, num_trees]. For example, result_iter =
|
||||
classifier.predict(...)
|
||||
for result_dict in result_iter: # access leaf index list by
|
||||
result_dict["leaf_index"] # which contains one leaf index per tree
|
||||
num_quantiles: Number of quantiles to build for numeric feature values.
|
||||
|
||||
Raises:
|
||||
ValueError: If learner_config is not valid.
|
||||
"""
|
||||
|
||||
def _model_fn(features, labels, mode, config):
|
||||
return model.ranking_model_builder(
|
||||
features=features,
|
||||
labels=labels,
|
||||
mode=mode,
|
||||
config=config,
|
||||
params={
|
||||
'head': head,
|
||||
'n_classes': 2,
|
||||
'feature_columns': feature_columns,
|
||||
'learner_config': learner_config,
|
||||
'num_trees': num_trees,
|
||||
'weight_column_name': weight_column_name,
|
||||
'examples_per_layer': examples_per_layer,
|
||||
'center_bias': center_bias,
|
||||
'logits_modifier_function': logits_modifier_function,
|
||||
'use_core_libs': True,
|
||||
'output_leaf_index': output_leaf_index,
|
||||
'ranking_model_pair_keys': ranking_model_pair_keys,
|
||||
'override_global_step_value': None,
|
||||
'num_quantiles': num_quantiles,
|
||||
},
|
||||
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
|
||||
|
||||
super(CoreGradientBoostedDecisionTreeRanker, self).__init__(
|
||||
model_fn=_model_fn, model_dir=model_dir, config=config)
|
||||
|
||||
|
||||
# When using this estimator, make sure to regularize the hessian (at least l2,
|
||||
# min_node_weight)!
|
||||
# TODO(nponomareva): extend to take multiple quantiles in one go.
|
||||
class CoreGradientBoostedDecisionTreeQuantileRegressor(
|
||||
core_estimator.Estimator):
|
||||
"""An estimator that does quantile regression and returns quantile estimates."""
|
||||
|
||||
def __init__(self,
|
||||
learner_config,
|
||||
examples_per_layer,
|
||||
quantiles,
|
||||
label_dimension=1,
|
||||
num_trees=None,
|
||||
feature_columns=None,
|
||||
weight_column_name=None,
|
||||
model_dir=None,
|
||||
config=None,
|
||||
label_keys=None,
|
||||
feature_engineering_fn=None,
|
||||
logits_modifier_function=None,
|
||||
center_bias=True,
|
||||
output_leaf_index=False,
|
||||
num_quantiles=100):
|
||||
"""Initializes a core version of GradientBoostedDecisionTreeEstimator.
|
||||
|
||||
Args:
|
||||
learner_config: A config for the learner.
|
||||
examples_per_layer: Number of examples to accumulate before growing a
|
||||
layer. It can also be a function that computes the number of examples
|
||||
based on the depth of the layer that's being built.
|
||||
quantiles: a list of quantiles for the loss, each between 0 and 1.
|
||||
label_dimension: Dimension of regression label. This is the size of the
|
||||
last dimension of the labels `Tensor` (typically, this has shape
|
||||
`[batch_size, label_dimension]`). When label_dimension>1, it is
|
||||
recommended to use multiclass strategy diagonal hessian or full hessian.
|
||||
num_trees: An int, number of trees to build.
|
||||
feature_columns: A list of feature columns.
|
||||
weight_column_name: Name of the column for weights, or None if not
|
||||
weighted.
|
||||
model_dir: Directory for model exports, etc.
|
||||
config: `RunConfig` object to configure the runtime settings.
|
||||
label_keys: Optional list of strings with size `[n_classes]` defining the
|
||||
label vocabulary. Only supported for `n_classes` > 2.
|
||||
feature_engineering_fn: Feature engineering function. Takes features and
|
||||
labels which are the output of `input_fn` and returns features and
|
||||
labels which will be fed into the model.
|
||||
logits_modifier_function: A modifier function for the logits.
|
||||
center_bias: Whether a separate tree should be created for first fitting
|
||||
the bias.
|
||||
output_leaf_index: whether to output leaf indices along with predictions
|
||||
during inference. The leaf node indexes are available in predictions
|
||||
dict by the key 'leaf_index'. For example, result_dict =
|
||||
classifier.predict(...)
|
||||
for example_prediction_result in result_dict: # access leaf index list
|
||||
by example_prediction_result["leaf_index"] # which contains one leaf
|
||||
index per tree
|
||||
num_quantiles: Number of quantiles to build for numeric feature values.
|
||||
"""
|
||||
if len(quantiles) > 1:
|
||||
raise ValueError('For now, just one quantile per estimator is supported')
|
||||
|
||||
def _model_fn(features, labels, mode, config):
|
||||
return model.model_builder(
|
||||
features=features,
|
||||
labels=labels,
|
||||
mode=mode,
|
||||
config=config,
|
||||
params={
|
||||
'head':
|
||||
core_quantile_regression_head(
|
||||
quantiles[0],
|
||||
label_dimension=label_dimension,
|
||||
weight_column=weight_column_name),
|
||||
'feature_columns':
|
||||
feature_columns,
|
||||
'learner_config':
|
||||
learner_config,
|
||||
'num_trees':
|
||||
num_trees,
|
||||
'weight_column_name':
|
||||
weight_column_name,
|
||||
'examples_per_layer':
|
||||
examples_per_layer,
|
||||
'center_bias':
|
||||
center_bias,
|
||||
'logits_modifier_function':
|
||||
logits_modifier_function,
|
||||
'use_core_libs':
|
||||
True,
|
||||
'output_leaf_index':
|
||||
output_leaf_index,
|
||||
'override_global_step_value':
|
||||
None,
|
||||
'num_quantiles':
|
||||
num_quantiles,
|
||||
},
|
||||
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
|
||||
|
||||
super(CoreGradientBoostedDecisionTreeQuantileRegressor, self).__init__(
|
||||
model_fn=_model_fn, model_dir=model_dir, config=config)
|
File diff suppressed because it is too large
@ -1,74 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utilities for converting between core and contrib feature columns."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||
from tensorflow.contrib.learn.python.learn.estimators import model_fn
|
||||
from tensorflow.contrib.learn.python.learn.estimators import model_fn as contrib_model_fn_lib
|
||||
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.estimator.export import export_output
|
||||
|
||||
_CORE_MODE_TO_CONTRIB_MODE_ = {
|
||||
model_fn_lib.ModeKeys.TRAIN: contrib_model_fn_lib.ModeKeys.TRAIN,
|
||||
model_fn_lib.ModeKeys.EVAL: contrib_model_fn_lib.ModeKeys.EVAL,
|
||||
model_fn_lib.ModeKeys.PREDICT: contrib_model_fn_lib.ModeKeys.INFER
|
||||
}
|
||||
|
||||
|
||||
def _core_mode_to_contrib_mode(mode):
|
||||
return _CORE_MODE_TO_CONTRIB_MODE_[mode]
|
||||
|
||||
|
||||
def _export_outputs_to_output_alternatives(export_outputs):
|
||||
"""Converts EstimatorSpec.export_outputs to output_alternatives.
|
||||
|
||||
Args:
|
||||
export_outputs: export_outputs created by create_estimator_spec.
|
||||
Returns:
|
||||
converted output_alternatives.
|
||||
"""
|
||||
output = {}
|
||||
if export_outputs is not None:
|
||||
for key, value in export_outputs.items():
|
||||
if isinstance(value, export_output.ClassificationOutput):
|
||||
exported_predictions = {
|
||||
prediction_key.PredictionKey.SCORES: value.scores,
|
||||
prediction_key.PredictionKey.CLASSES: value.classes
|
||||
}
|
||||
output[key] = (constants.ProblemType.CLASSIFICATION,
|
||||
exported_predictions)
|
||||
return output
|
||||
return None
|
||||
|
||||
|
||||
def estimator_spec_to_model_fn_ops(estimator_spec, export_alternatives=False):
|
||||
if export_alternatives:
|
||||
alternatives = _export_outputs_to_output_alternatives(
|
||||
estimator_spec.export_outputs)
|
||||
else:
|
||||
alternatives = []
|
||||
|
||||
return model_fn.ModelFnOps(
|
||||
mode=_core_mode_to_contrib_mode(estimator_spec.mode),
|
||||
predictions=estimator_spec.predictions,
|
||||
loss=estimator_spec.loss,
|
||||
train_op=estimator_spec.train_op,
|
||||
eval_metric_ops=estimator_spec.eval_metric_ops,
|
||||
output_alternatives=alternatives)
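
A sketch of the intended call pattern; spec is assumed to be an EstimatorSpec
returned by a core head's create_estimator_spec inside a contrib model_fn:

# spec: a core EstimatorSpec built elsewhere in the model_fn.
model_fn_ops = estimator_spec_to_model_fn_ops(spec, export_alternatives=True)
# The result is a contrib ModelFnOps: the mode is translated from core
# ModeKeys (PREDICT becomes INFER) and any ClassificationOutput entries in
# export_outputs are converted to output_alternatives.
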
|
@ -1,440 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""GTFlow Model definitions."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
|
||||
from tensorflow.contrib import learn
|
||||
from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils
|
||||
from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks
|
||||
from tensorflow.contrib.boosted_trees.python.ops import model_ops
|
||||
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.training import training_util
|
||||
from google.protobuf import text_format
|
||||
from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
|
||||
|
||||
|
||||
class ModelBuilderOutputType(object):
|
||||
MODEL_FN_OPS = 0
|
||||
ESTIMATOR_SPEC = 1
|
||||
|
||||
|
||||
def model_builder(features,
|
||||
labels,
|
||||
mode,
|
||||
params,
|
||||
config,
|
||||
output_type=ModelBuilderOutputType.MODEL_FN_OPS):
|
||||
"""Multi-machine batch gradient descent tree model.
|
||||
|
||||
Args:
|
||||
features: `Tensor` or `dict` of `Tensor` objects.
|
||||
labels: Labels used to train on.
|
||||
mode: Mode we are in. (TRAIN/EVAL/INFER)
|
||||
params: A dict of hyperparameters.
|
||||
The following hyperparameters are expected:
|
||||
* head: A `Head` instance.
|
||||
* learner_config: A config for the learner.
|
||||
* feature_columns: An iterable containing all the feature columns used by
|
||||
the model.
|
||||
* examples_per_layer: Number of examples to accumulate before growing a
|
||||
layer. It can also be a function that computes the number of examples
|
||||
based on the depth of the layer that's being built.
|
||||
* weight_column_name: The name of weight column.
|
||||
* center_bias: Whether a separate tree should be created for first fitting
|
||||
the bias.
|
||||
* override_global_step_value: If after the training is done, global step
|
||||
value must be reset to this value. This is particularly useful for hyper
|
||||
parameter tuning, which can't recognize early stopping due to the number
|
||||
of trees. If None, no override of global step will happen.
|
||||
config: `RunConfig` of the estimator.
|
||||
output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
|
||||
(new interface).
|
||||
|
||||
Returns:
|
||||
A `ModelFnOps` object.
|
||||
Raises:
|
||||
ValueError: if inputs are not valid.
|
||||
"""
|
||||
head = params["head"]
|
||||
learner_config = params["learner_config"]
|
||||
examples_per_layer = params["examples_per_layer"]
|
||||
feature_columns = params["feature_columns"]
|
||||
weight_column_name = params["weight_column_name"]
|
||||
num_trees = params["num_trees"]
|
||||
use_core_libs = params["use_core_libs"]
|
||||
logits_modifier_function = params["logits_modifier_function"]
|
||||
output_leaf_index = params["output_leaf_index"]
|
||||
override_global_step_value = params.get("override_global_step_value", None)
|
||||
num_quantiles = params["num_quantiles"]
|
||||
|
||||
if features is None:
|
||||
raise ValueError("At least one feature must be specified.")
|
||||
|
||||
if config is None:
|
||||
raise ValueError("Missing estimator RunConfig.")
|
||||
if config.session_config is not None:
|
||||
session_config = config.session_config
|
||||
session_config.allow_soft_placement = True
|
||||
else:
|
||||
session_config = config_pb2.ConfigProto(allow_soft_placement=True)
|
||||
config = config.replace(session_config=session_config)
|
||||
|
||||
center_bias = params["center_bias"]
|
||||
|
||||
if isinstance(features, ops.Tensor):
|
||||
features = {features.name: features}
|
||||
|
||||
# Make a shallow copy of features to ensure downstream usage
|
||||
# is unaffected by modifications in the model function.
|
||||
training_features = copy.copy(features)
|
||||
training_features.pop(weight_column_name, None)
|
||||
global_step = training_util.get_global_step()
|
||||
|
||||
initial_ensemble = ""
|
||||
if learner_config.each_tree_start.nodes:
|
||||
if learner_config.each_tree_start_num_layers <= 0:
|
||||
raise ValueError("You must provide each_tree_start_num_layers.")
|
||||
num_layers = learner_config.each_tree_start_num_layers
|
||||
initial_ensemble = """
|
||||
trees { %s }
|
||||
tree_weights: 0.1
|
||||
tree_metadata {
|
||||
num_tree_weight_updates: 1
|
||||
num_layers_grown: %d
|
||||
is_finalized: false
|
||||
}
|
||||
""" % (text_format.MessageToString(
|
||||
learner_config.each_tree_start), num_layers)
|
||||
tree_ensemble_proto = tree_config_pb2.DecisionTreeEnsembleConfig()
|
||||
text_format.Merge(initial_ensemble, tree_ensemble_proto)
|
||||
initial_ensemble = tree_ensemble_proto.SerializeToString()
|
||||
|
||||
with ops.device(global_step.device):
|
||||
ensemble_handle = model_ops.tree_ensemble_variable(
|
||||
stamp_token=0,
|
||||
tree_ensemble_config=initial_ensemble, # Initialize the ensemble.
|
||||
name="ensemble_model")
|
||||
|
||||
# Create GBDT model.
|
||||
gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
|
||||
is_chief=config.is_chief,
|
||||
num_ps_replicas=config.num_ps_replicas,
|
||||
ensemble_handle=ensemble_handle,
|
||||
center_bias=center_bias,
|
||||
examples_per_layer=examples_per_layer,
|
||||
learner_config=learner_config,
|
||||
feature_columns=feature_columns,
|
||||
logits_dimension=head.logits_dimension,
|
||||
features=training_features,
|
||||
use_core_columns=use_core_libs,
|
||||
output_leaf_index=output_leaf_index,
|
||||
num_quantiles=num_quantiles)
|
||||
with ops.name_scope("gbdt", "gbdt_optimizer"):
|
||||
predictions_dict = gbdt_model.predict(mode)
|
||||
logits = predictions_dict["predictions"]
|
||||
if logits_modifier_function:
|
||||
logits = logits_modifier_function(logits, features, mode)
|
||||
|
||||
def _train_op_fn(loss):
|
||||
"""Returns the op to optimize the loss."""
|
||||
update_op = gbdt_model.train(loss, predictions_dict, labels)
|
||||
with ops.control_dependencies(
|
||||
[update_op]), (ops.colocate_with(global_step)):
|
||||
update_op = state_ops.assign_add(global_step, 1).op
|
||||
return update_op
|
||||
|
||||
create_estimator_spec_op = getattr(head, "create_estimator_spec", None)
|
||||
|
||||
training_hooks = []
|
||||
if num_trees:
|
||||
if center_bias:
|
||||
num_trees += 1
|
||||
|
||||
finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
|
||||
training_hooks.append(
|
||||
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
|
||||
finalized_trees,
|
||||
override_global_step_value))
|
||||
|
||||
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
|
||||
if use_core_libs and callable(create_estimator_spec_op):
|
||||
model_fn_ops = head.create_estimator_spec(
|
||||
features=features,
|
||||
mode=mode,
|
||||
labels=labels,
|
||||
train_op_fn=_train_op_fn,
|
||||
logits=logits)
|
||||
model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(
|
||||
model_fn_ops)
|
||||
else:
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
features=features,
|
||||
mode=mode,
|
||||
labels=labels,
|
||||
train_op_fn=_train_op_fn,
|
||||
logits=logits)
|
||||
|
||||
if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
|
||||
model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
|
||||
gbdt_batch.LEAF_INDEX]
|
||||
|
||||
model_fn_ops.training_hooks.extend(training_hooks)
|
||||
return model_fn_ops
|
||||
elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC:
|
||||
assert callable(create_estimator_spec_op)
|
||||
estimator_spec = head.create_estimator_spec(
|
||||
features=features,
|
||||
mode=mode,
|
||||
labels=labels,
|
||||
train_op_fn=_train_op_fn,
|
||||
logits=logits)
|
||||
|
||||
if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
|
||||
estimator_spec.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
|
||||
gbdt_batch.LEAF_INDEX]
|
||||
|
||||
estimator_spec = estimator_spec._replace(
|
||||
training_hooks=training_hooks + list(estimator_spec.training_hooks))
|
||||
    return estimator_spec

  return model_fn_ops


def ranking_model_builder(features,
                          labels,
                          mode,
                          params,
                          config,
                          output_type=ModelBuilderOutputType.MODEL_FN_OPS):
  """Multi-machine batch gradient descent tree model for ranking.

  Args:
    features: `Tensor` or `dict` of `Tensor` objects.
    labels: Labels used to train on.
    mode: Mode we are in. (TRAIN/EVAL/INFER)
    params: A dict of hyperparameters.
      The following hyperparameters are expected:
      * head: A `Head` instance.
      * learner_config: A config for the learner.
      * feature_columns: An iterable containing all the feature columns used by
          the model.
      * examples_per_layer: Number of examples to accumulate before growing a
          layer. It can also be a function that computes the number of examples
          based on the depth of the layer that's being built.
      * weight_column_name: The name of weight column.
      * center_bias: Whether a separate tree should be created for first fitting
          the bias.
      * ranking_model_pair_keys (Optional): Keys to distinguish between features
        for left and right part of the training pairs for ranking. For example,
        for an Example with features "a.f1" and "b.f1", the keys would be
        ("a", "b").
      * override_global_step_value: If after the training is done, global step
        value must be reset to this value. This is particularly useful for hyper
        parameter tuning, which can't recognize early stopping due to the number
        of trees. If None, no override of global step will happen.
    config: `RunConfig` of the estimator.
    output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
      (new interface).

  Returns:
    A `ModelFnOps` object.
  Raises:
    ValueError: if inputs are not valid.
  """
  head = params["head"]
  learner_config = params["learner_config"]
  examples_per_layer = params["examples_per_layer"]
  feature_columns = params["feature_columns"]
  weight_column_name = params["weight_column_name"]
  num_trees = params["num_trees"]
  use_core_libs = params["use_core_libs"]
  logits_modifier_function = params["logits_modifier_function"]
  output_leaf_index = params["output_leaf_index"]
  ranking_model_pair_keys = params["ranking_model_pair_keys"]
  override_global_step_value = params.get("override_global_step_value", None)
  num_quantiles = params["num_quantiles"]

  if features is None:
    raise ValueError("At least one feature must be specified.")

  if config is None:
    raise ValueError("Missing estimator RunConfig.")

  center_bias = params["center_bias"]

  if isinstance(features, ops.Tensor):
    features = {features.name: features}

  # Make a shallow copy of features to ensure downstream usage
  # is unaffected by modifications in the model function.
  training_features = copy.copy(features)
  training_features.pop(weight_column_name, None)
  global_step = training_util.get_global_step()
  with ops.device(global_step.device):
    ensemble_handle = model_ops.tree_ensemble_variable(
        stamp_token=0,
        tree_ensemble_config="",  # Initialize an empty ensemble.
        name="ensemble_model")

  # Extract the features.
  if mode == learn.ModeKeys.TRAIN or mode == learn.ModeKeys.EVAL:
    # For ranking pairwise training, we extract two sets of features.
    if len(ranking_model_pair_keys) != 2:
      raise ValueError("You must provide keys for ranking.")
    left_pair_key = ranking_model_pair_keys[0]
    right_pair_key = ranking_model_pair_keys[1]
    if left_pair_key is None or right_pair_key is None:
      raise ValueError("Both pair keys should be provided for ranking.")

    features_1 = {}
    features_2 = {}
    for name in training_features:
      feature = training_features[name]
      new_name = name[2:]
      if name.startswith(left_pair_key + "."):
        features_1[new_name] = feature
      else:
        assert name.startswith(right_pair_key + ".")
        features_2[new_name] = feature

    main_features = features_1
    supplementary_features = features_2
  else:
    # For non-ranking or inference ranking, we have only 1 set of features.
    main_features = training_features

  # Create GBDT model.
  gbdt_model_main = gbdt_batch.GradientBoostedDecisionTreeModel(
      is_chief=config.is_chief,
      num_ps_replicas=config.num_ps_replicas,
      ensemble_handle=ensemble_handle,
      center_bias=center_bias,
      examples_per_layer=examples_per_layer,
      learner_config=learner_config,
      feature_columns=feature_columns,
      logits_dimension=head.logits_dimension,
      features=main_features,
      use_core_columns=use_core_libs,
      output_leaf_index=output_leaf_index,
      num_quantiles=num_quantiles)

  with ops.name_scope("gbdt", "gbdt_optimizer"):
    # Logits for inference.
    if mode == learn.ModeKeys.INFER:
      predictions_dict = gbdt_model_main.predict(mode)
      logits = predictions_dict[gbdt_batch.PREDICTIONS]
      if logits_modifier_function:
        logits = logits_modifier_function(logits, features, mode)
    else:
      gbdt_model_supplementary = gbdt_batch.GradientBoostedDecisionTreeModel(
          is_chief=config.is_chief,
          num_ps_replicas=config.num_ps_replicas,
          ensemble_handle=ensemble_handle,
          center_bias=center_bias,
          examples_per_layer=examples_per_layer,
          learner_config=learner_config,
          feature_columns=feature_columns,
          logits_dimension=head.logits_dimension,
          features=supplementary_features,
          use_core_columns=use_core_libs,
          output_leaf_index=output_leaf_index)

      # Logits for train and eval.
      if not supplementary_features:
        raise ValueError("Features for ranking must be specified.")

      predictions_dict_1 = gbdt_model_main.predict(mode)
      predictions_1 = predictions_dict_1[gbdt_batch.PREDICTIONS]

      predictions_dict_2 = gbdt_model_supplementary.predict(mode)
      predictions_2 = predictions_dict_2[gbdt_batch.PREDICTIONS]

      logits = predictions_1 - predictions_2
      if logits_modifier_function:
        logits = logits_modifier_function(logits, features, mode)

      predictions_dict = predictions_dict_1
      predictions_dict[gbdt_batch.PREDICTIONS] = logits

  def _train_op_fn(loss):
    """Returns the op to optimize the loss."""
    update_op = gbdt_model_main.train(loss, predictions_dict, labels)
    with ops.control_dependencies(
        [update_op]), (ops.colocate_with(global_step)):
      update_op = state_ops.assign_add(global_step, 1).op
      return update_op

  create_estimator_spec_op = getattr(head, "create_estimator_spec", None)

  training_hooks = []
  if num_trees:
    if center_bias:
      num_trees += 1

    finalized_trees, attempted_trees = (
        gbdt_model_main.get_number_of_trees_tensor())
    training_hooks.append(
        trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
                                      finalized_trees,
                                      override_global_step_value))

  if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
    if use_core_libs and callable(create_estimator_spec_op):
      model_fn_ops = head.create_estimator_spec(
          features=features,
          mode=mode,
          labels=labels,
          train_op_fn=_train_op_fn,
          logits=logits)
      model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(
          model_fn_ops)
    else:
      model_fn_ops = head.create_model_fn_ops(
          features=features,
          mode=mode,
          labels=labels,
          train_op_fn=_train_op_fn,
          logits=logits)

    if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
      model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
          gbdt_batch.LEAF_INDEX]

    model_fn_ops.training_hooks.extend(training_hooks)
    return model_fn_ops

  elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC:
    assert callable(create_estimator_spec_op)
    estimator_spec = head.create_estimator_spec(
        features=features,
        mode=mode,
        labels=labels,
        train_op_fn=_train_op_fn,
        logits=logits)

    estimator_spec = estimator_spec._replace(
        training_hooks=training_hooks + list(estimator_spec.training_hooks))
    return estimator_spec

  return model_fn_ops
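The pairwise branch above strips the pair-key prefix from each feature name to build the left/right feature dicts. A minimal standalone sketch of that behavior (illustrative only, not part of the removed file), assuming ranking_model_pair_keys=("a", "b"):

# Illustrative sketch: how the pairwise feature split behaves.
features = {"a.f1": 1.0, "a.f2": 2.0, "b.f1": 3.0, "b.f2": 4.0}
left, right = {}, {}
for name, value in features.items():
  stripped = name[2:]  # drop the single-character "a." / "b." prefix
  (left if name.startswith("a.") else right)[stripped] = value
assert left == {"f1": 1.0, "f2": 2.0}
assert right == {"f1": 3.0, "f2": 4.0}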
@ -1,230 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks for use with GTFlow Estimator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.contrib.learn.python.learn import session_run_hook
from tensorflow.contrib.learn.python.learn.session_run_hook import SessionRunArgs
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
from tensorflow.python.training.summary_io import SummaryWriterCache


class FeatureImportanceSummarySaver(session_run_hook.SessionRunHook):
  """Hook to save feature importance summaries."""

  def __init__(self, model_dir, every_n_steps=1):
    """Create a FeatureImportanceSummarySaver Hook.

    This hook creates scalar summaries representing feature importance
    for each feature column during training.

    Args:
      model_dir: model base output directory.
      every_n_steps: frequency, in number of steps, for logging summaries.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if model_dir is None:
      raise ValueError("model dir must be specified.")
    self._model_dir = model_dir
    self._every_n_steps = every_n_steps
    self._last_triggered_step = None

  def begin(self):
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use FeatureImportanceSummarySaver.")
    graph = ops.get_default_graph()
    self._feature_names_tensor = graph.get_tensor_by_name(
        "gbdt/feature_names:0")
    self._feature_usage_counts_tensor = graph.get_tensor_by_name(
        "gbdt/feature_usage_counts:0")
    self._feature_gains_tensor = graph.get_tensor_by_name(
        "gbdt/feature_gains:0")

  def before_run(self, run_context):
    del run_context  # Unused by feature importance summary saver hook.
    requests = {
        "global_step": self._global_step_tensor,
        "feature_names": self._feature_names_tensor,
        "feature_usage_counts": self._feature_usage_counts_tensor,
        "feature_gains": self._feature_gains_tensor
    }
    return SessionRunArgs(requests)

  def after_run(self, run_context, run_values):
    del run_context  # Unused by feature importance summary saver hook.

    # Read result tensors.
    global_step = run_values.results["global_step"]
    feature_names = run_values.results["feature_names"]
    feature_usage_counts = run_values.results["feature_usage_counts"]
    feature_gains = run_values.results["feature_gains"]

    # Ensure summaries are logged at desired frequency
    if (self._last_triggered_step is not None and
        global_step < self._last_triggered_step + self._every_n_steps):
      return

    # Validate tensors.
    if (len(feature_names) != len(feature_usage_counts) or
        len(feature_names) != len(feature_gains)):
      raise RuntimeError(
          "Feature names and importance measures have inconsistent lengths.")

    # Compute total usage.
    total_usage_count = 0.0
    for usage_count in feature_usage_counts:
      total_usage_count += usage_count
    usage_count_norm = 1.0 / total_usage_count if total_usage_count else 1.0

    # Compute total gain.
    total_gain = 0.0
    for gain in feature_gains:
      total_gain += gain
    gain_norm = 1.0 / total_gain if total_gain else 1.0

    # Output summary for each feature.
    self._last_triggered_step = global_step
    for (name, usage_count, gain) in zip(feature_names, feature_usage_counts,
                                         feature_gains):
      output_dir = os.path.join(self._model_dir, name.decode("utf-8"))
      summary_writer = SummaryWriterCache.get(output_dir)
      usage_count_summary = Summary(value=[
          Summary.Value(
              tag="feature_importance/usage_counts", simple_value=usage_count)
      ])
      usage_fraction_summary = Summary(value=[
          Summary.Value(
              tag="feature_importance/usage_fraction",
              simple_value=usage_count * usage_count_norm)
      ])
      summary_writer.add_summary(usage_count_summary, global_step)
      summary_writer.add_summary(usage_fraction_summary, global_step)
      gains_summary = Summary(value=[
          Summary.Value(tag="feature_importance/gains", simple_value=gain)
      ])
      gains_fraction_summary = Summary(value=[
          Summary.Value(
              tag="feature_importance/gains_fraction",
              simple_value=gain * gain_norm)
      ])
      summary_writer.add_summary(gains_summary, global_step)
      summary_writer.add_summary(gains_fraction_summary, global_step)


class FeedFnHook(session_run_hook.SessionRunHook):
  """Runs feed_fn and sets the feed_dict accordingly."""

  def __init__(self, feed_fn):
    self.feed_fn = feed_fn

  def before_run(self, run_context):
    del run_context  # unused by FeedFnHook.
    return session_run_hook.SessionRunArgs(fetches=None, feed_dict=self.feed_fn)


class StopAfterNTrees(session_run_hook.SessionRunHook):
  """Stop training after building N full trees."""

  def __init__(self, n, num_attempted_trees_tensor, num_finalized_trees_tensor,
               override_global_step_value=None):
    self._num_trees = n
    # num_attempted_trees_tensor and num_finalized_trees_tensor are both
    # tensors.
    self._num_attempted_trees_tensor = num_attempted_trees_tensor
    self._num_finalized_trees_tensor = num_finalized_trees_tensor
    self._override_global_step_value = override_global_step_value

  def begin(self):
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created.")

    if self._override_global_step_value is not None:
      self._override_global_step_op = state_ops.assign(
          self._global_step_tensor, self._override_global_step_value)

  def before_run(self, run_context):
    del run_context  # unused by StopTrainingAfterNTrees.
    return session_run_hook.SessionRunArgs({
        "num_attempted_trees": self._num_attempted_trees_tensor,
        "num_finalized_trees": self._num_finalized_trees_tensor,
    })

  def after_run(self, run_context, run_values):
    num_attempted_trees = run_values.results["num_attempted_trees"]
    num_finalized_trees = run_values.results["num_finalized_trees"]
    assert num_attempted_trees is not None
    assert num_finalized_trees is not None
    # Stop when the required number of finalized trees is reached, or when we
    # try enough times to build a tree but keep failing.
    if (num_finalized_trees >= self._num_trees or
        num_attempted_trees > 2 * self._num_trees):
      logging.info("Requesting stop since we have reached %d trees.",
                   num_finalized_trees)
      if self._override_global_step_value is not None:
        logging.info("Overriding global steps value.")
        run_context.session.run(self._override_global_step_op)
      run_context.request_stop()


class SwitchTrainOp(session_run_hook.SessionRunHook):
  """Hook that switches the train op after specified number of steps.

  Hook that replaces the train op depending on the number of steps of training
  that have taken place. The first_train_op is used till train_steps steps
  are reached. Thereafter the second_train_op is used.
  """

  def __init__(self, first_train_op, train_steps, second_train_op):
    """Initializes a `SwitchTrainOp`."""
    self._first_train_op = first_train_op
    self._second_train_op = second_train_op
    self._train_steps = train_steps

  def _get_train_op_for_global_step(self, current_step):
    """Gets train_op for current global step."""
    if current_step < self._train_steps:
      return self._first_train_op
    return self._second_train_op

  def begin(self):
    self._global_step_tensor = training_util.get_global_step()
    self._current_train_op = control_flow_ops.no_op()
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use SwitchTrainOp.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return session_run_hook.SessionRunArgs(
        {"global_step": self._global_step_tensor,
         "train_op": self._current_train_op})

  def after_run(self, run_context, run_values):
    self._current_train_op = self._get_train_op_for_global_step(
        run_values.results["global_step"])
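For context, a minimal sketch of how StopAfterNTrees is typically wired into a hook list, mirroring the model builders above; `gbdt_model` and the chosen tree count are assumptions for illustration, not part of the removed code:

# Illustrative sketch (assumed): attach the stop hook to a training run.
# `gbdt_model` stands in for a GradientBoostedDecisionTreeModel instance.
finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
stop_hook = trainer_hooks.StopAfterNTrees(
    n=10,
    num_attempted_trees_tensor=attempted_trees,
    num_finalized_trees_tensor=finalized_trees,
    override_global_step_value=None)
training_hooks = [stop_hook]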
@ -1,76 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for trainer hooks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tempfile

from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import monitored_session


class FeatureImportanceSummarySaverTest(test_util.TensorFlowTestCase):

  def test_invalid_input(self):
    with self.assertRaises(ValueError):
      trainer_hooks.FeatureImportanceSummarySaver(model_dir=None)

  def test_invalid_graph(self):
    # Create inputs.
    model_dir = tempfile.mkdtemp()
    hook = trainer_hooks.FeatureImportanceSummarySaver(model_dir)
    with ops.Graph().as_default():
      # Begin won't be able to find the required tensors in the graph.
      _ = variables.get_or_create_global_step()
      with self.assertRaises(KeyError):
        hook.begin()

  def test_run(self):
    # Create inputs.
    model_dir = tempfile.mkdtemp()
    hook = trainer_hooks.FeatureImportanceSummarySaver(model_dir)
    with ops.Graph().as_default(), tf_session.Session() as sess:
      global_step = variables.get_or_create_global_step()
      with ops.name_scope("gbdt"):
        constant_op.constant(["featA", "featB"], name="feature_names")
        constant_op.constant([0, 2], name="feature_usage_counts")
        constant_op.constant([0, 0.8], name="feature_gains")
      # Begin finds tensors in the graph.
      hook.begin()
      sess.run(tf_variables.global_variables_initializer())
      # Run hook in a monitored session.
      train_op = state_ops.assign_add(global_step, 1)
      mon_sess = monitored_session._HookedSession(sess, [hook])
      mon_sess.run(train_op)
      hook.end(sess)
      # Ensure output summary dirs are created.
      self.assertTrue(os.path.exists(os.path.join(model_dir, "featA")))
      self.assertTrue(os.path.exists(os.path.join(model_dir, "featB")))


if __name__ == "__main__":
  googletest.main()
@ -1,169 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
r"""Demonstrates multiclass MNIST TF Boosted trees example.
|
||||
|
||||
This example demonstrates how to run experiments with TF Boosted Trees on
|
||||
a binary dataset. We use digits 4 and 9 from the original MNIST dataset.
|
||||
|
||||
Example Usage:
|
||||
python tensorflow/contrib/boosted_trees/examples/binary_mnist.py \
|
||||
--output_dir="/tmp/binary_mnist" --depth=4 --learning_rate=0.3 \
|
||||
--batch_size=10761 --examples_per_layer=10761 --eval_batch_size=1030 \
|
||||
--num_eval_steps=1 --num_trees=10 --l2=1 --vmodule=training_ops=1
|
||||
|
||||
When training is done, accuracy on eval data is reported. Point tensorboard
|
||||
to the directory for the run to see how the training progresses:
|
||||
|
||||
tensorboard --logdir=/tmp/binary_mnist
|
||||
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier
|
||||
from tensorflow.contrib.boosted_trees.proto import learner_pb2
|
||||
from tensorflow.contrib.learn import learn_runner
|
||||
|
||||
|
||||
def get_input_fn(data,
|
||||
batch_size,
|
||||
capacity=10000,
|
||||
min_after_dequeue=3000):
|
||||
"""Input function over MNIST data."""
|
||||
# Keep only 4 and 9 digits.
|
||||
ids = np.where((data.labels == 4) | (data.labels == 9))
|
||||
images = data.images[ids]
|
||||
labels = data.labels[ids]
|
||||
# Make digit 4 label 1, 9 is 0.
|
||||
labels = labels == 4
|
||||
|
||||
def _input_fn():
|
||||
"""Prepare features and labels."""
|
||||
images_batch, labels_batch = tf.train.shuffle_batch(
|
||||
tensors=[images,
|
||||
labels.astype(np.int32)],
|
||||
batch_size=batch_size,
|
||||
capacity=capacity,
|
||||
min_after_dequeue=min_after_dequeue,
|
||||
enqueue_many=True,
|
||||
num_threads=4)
|
||||
features_map = {"images": images_batch}
|
||||
return features_map, labels_batch
|
||||
|
||||
return _input_fn
|
||||
|
||||
|
||||
# Main config - creates a TF Boosted Trees Estimator based on flags.
|
||||
def _get_tfbt(output_dir):
|
||||
"""Configures TF Boosted Trees estimator based on flags."""
|
||||
learner_config = learner_pb2.LearnerConfig()
|
||||
|
||||
learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate
|
||||
learner_config.regularization.l1 = 0.0
|
||||
learner_config.regularization.l2 = FLAGS.l2 / FLAGS.examples_per_layer
|
||||
learner_config.constraints.max_tree_depth = FLAGS.depth
|
||||
|
||||
growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER
|
||||
learner_config.growing_mode = growing_mode
|
||||
run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300)
|
||||
|
||||
# Create a TF Boosted trees estimator that can take in custom loss.
|
||||
estimator = GradientBoostedDecisionTreeClassifier(
|
||||
learner_config=learner_config,
|
||||
examples_per_layer=FLAGS.examples_per_layer,
|
||||
model_dir=output_dir,
|
||||
num_trees=FLAGS.num_trees,
|
||||
center_bias=False,
|
||||
config=run_config)
|
||||
return estimator
|
||||
|
||||
|
||||
def _make_experiment_fn(output_dir):
|
||||
"""Creates experiment for gradient boosted decision trees."""
|
||||
data = tf.contrib.learn.datasets.mnist.load_mnist()
|
||||
train_input_fn = get_input_fn(data.train, FLAGS.batch_size)
|
||||
eval_input_fn = get_input_fn(data.validation, FLAGS.eval_batch_size)
|
||||
|
||||
return tf.contrib.learn.Experiment(
|
||||
estimator=_get_tfbt(output_dir),
|
||||
train_input_fn=train_input_fn,
|
||||
eval_input_fn=eval_input_fn,
|
||||
train_steps=None,
|
||||
eval_steps=FLAGS.num_eval_steps,
|
||||
eval_metrics=None)
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
learn_runner.run(
|
||||
experiment_fn=_make_experiment_fn,
|
||||
output_dir=FLAGS.output_dir,
|
||||
schedule="train_and_evaluate")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
parser = argparse.ArgumentParser()
|
||||
# Define the list of flags that users can change.
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Choose the dir for the output.")
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="The batch size for reading data.")
|
||||
parser.add_argument(
|
||||
"--eval_batch_size",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Size of the batch for eval.")
|
||||
parser.add_argument(
|
||||
"--num_eval_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of steps to run evaluation for.")
|
||||
# Flags for gradient boosted trees config.
|
||||
parser.add_argument(
|
||||
"--depth", type=int, default=4, help="Maximum depth of weak learners.")
|
||||
parser.add_argument(
|
||||
"--l2", type=float, default=1.0, help="l2 regularization per batch.")
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="Learning rate (shrinkage weight) with which each new tree is added."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--examples_per_layer",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of examples to accumulate stats for per layer.")
|
||||
parser.add_argument(
|
||||
"--num_trees",
|
||||
type=int,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Number of trees to grow before stopping.")
|
||||
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
@ -1,171 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
r"""Demonstrates a regression on Boston housing data.
|
||||
|
||||
This example demonstrates how to run experiments with TF Boosted Trees on
|
||||
a regression dataset. We split all the data into 20% test and 80% train,
|
||||
and are using l2 loss and l2 regularization.
|
||||
|
||||
Example Usage:
|
||||
|
||||
python tensorflow/contrib/boosted_trees/examples/boston.py \
|
||||
--batch_size=404 --output_dir="/tmp/boston" --depth=4 --learning_rate=0.1 \
|
||||
--num_eval_steps=1 --num_trees=500 --l2=0.001 \
|
||||
--vmodule=training_ops=1
|
||||
|
||||
When training is done, mean squared error on eval data is reported.
|
||||
Point tensorboard to the directory for the run to see how the training
|
||||
progresses:
|
||||
|
||||
tensorboard --logdir=/tmp/boston
|
||||
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.boosted_trees.estimator_batch import custom_export_strategy
|
||||
from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeRegressor
|
||||
from tensorflow.contrib.boosted_trees.proto import learner_pb2
|
||||
from tensorflow.contrib.layers.python.layers import feature_column
|
||||
from tensorflow.contrib.learn import learn_runner
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
_BOSTON_NUM_FEATURES = 13
|
||||
|
||||
|
||||
# Main config - creates a TF Boosted Trees Estimator based on flags.
|
||||
def _get_tfbt(output_dir, feature_cols):
|
||||
"""Configures TF Boosted Trees estimator based on flags."""
|
||||
learner_config = learner_pb2.LearnerConfig()
|
||||
learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate
|
||||
learner_config.regularization.l1 = 0.0
|
||||
learner_config.regularization.l2 = FLAGS.l2
|
||||
learner_config.constraints.max_tree_depth = FLAGS.depth
|
||||
|
||||
run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300)
|
||||
|
||||
# Create a TF Boosted trees regression estimator.
|
||||
estimator = GradientBoostedDecisionTreeRegressor(
|
||||
learner_config=learner_config,
|
||||
# This should be the number of examples. For large datasets it can be
|
||||
# larger than the batch_size.
|
||||
examples_per_layer=FLAGS.batch_size,
|
||||
feature_columns=feature_cols,
|
||||
label_dimension=1,
|
||||
model_dir=output_dir,
|
||||
num_trees=FLAGS.num_trees,
|
||||
center_bias=False,
|
||||
config=run_config)
|
||||
return estimator
|
||||
|
||||
|
||||
def _convert_fn(dtec, sorted_feature_names, num_dense, num_sparse_float,
|
||||
num_sparse_int, export_dir, unused_eval_result):
|
||||
universal_format = custom_export_strategy.convert_to_universal_format(
|
||||
dtec, sorted_feature_names, num_dense, num_sparse_float, num_sparse_int)
|
||||
with tf.gfile.GFile(os.path.join(
|
||||
compat.as_bytes(export_dir), compat.as_bytes("tree_proto")), "w") as f:
|
||||
f.write(str(universal_format))
|
||||
|
||||
|
||||
def _make_experiment_fn(output_dir):
|
||||
"""Creates experiment for gradient boosted decision trees."""
|
||||
(x_train, y_train), (x_test,
|
||||
y_test) = tf.keras.datasets.boston_housing.load_data()
|
||||
|
||||
train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
|
||||
x={"x": x_train},
|
||||
y=y_train,
|
||||
batch_size=FLAGS.batch_size,
|
||||
num_epochs=None,
|
||||
shuffle=True)
|
||||
eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
|
||||
x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False)
|
||||
|
||||
feature_columns = [
|
||||
feature_column.real_valued_column("x", dimension=_BOSTON_NUM_FEATURES)
|
||||
]
|
||||
feature_spec = tf.contrib.layers.create_feature_spec_for_parsing(
|
||||
feature_columns)
|
||||
serving_input_fn = tf.contrib.learn.utils.build_parsing_serving_input_fn(
|
||||
feature_spec)
|
||||
# An export strategy that outputs the feature importance and also exports
|
||||
# the internal tree representation in another format.
|
||||
export_strategy = custom_export_strategy.make_custom_export_strategy(
|
||||
"exports",
|
||||
convert_fn=_convert_fn,
|
||||
feature_columns=feature_columns,
|
||||
export_input_fn=serving_input_fn)
|
||||
return tf.contrib.learn.Experiment(
|
||||
estimator=_get_tfbt(output_dir, feature_columns),
|
||||
train_input_fn=train_input_fn,
|
||||
eval_input_fn=eval_input_fn,
|
||||
train_steps=None,
|
||||
eval_steps=FLAGS.num_eval_steps,
|
||||
eval_metrics=None,
|
||||
export_strategies=[export_strategy])
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
learn_runner.run(
|
||||
experiment_fn=_make_experiment_fn,
|
||||
output_dir=FLAGS.output_dir,
|
||||
schedule="train_and_evaluate")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
parser = argparse.ArgumentParser()
|
||||
# Define the list of flags that users can change.
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="The batch size for reading data.")
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Choose the dir for the output.")
|
||||
parser.add_argument(
|
||||
"--num_eval_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of steps to run evaluation for.")
|
||||
# Flags for gradient boosted trees config.
|
||||
parser.add_argument(
|
||||
"--depth", type=int, default=4, help="Maximum depth of weak learners.")
|
||||
parser.add_argument(
|
||||
"--l2", type=float, default=1.0, help="l2 regularization per batch.")
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="Learning rate (shrinkage weight) with which each new tree is added."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_trees",
|
||||
type=int,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Number of trees to grow before stopping.")
|
||||
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
@ -1,165 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
r"""Regression on Boston housing data using DNNBoostedTreeCombinedRegressor.
|
||||
|
||||
Example Usage:
|
||||
|
||||
python tensorflow/contrib/boosted_trees/examples/boston_combined.py \
|
||||
--batch_size=404 --output_dir="/tmp/boston" \
|
||||
--dnn_hidden_units="8,4" --dnn_steps_to_train=1000 \
|
||||
--tree_depth=4 --tree_learning_rate=0.1 \
|
||||
--num_trees=100 --tree_l2=0.001 --num_eval_steps=1 \
|
||||
--vmodule=training_ops=1
|
||||
|
||||
When training is done, mean squared error on eval data is reported.
|
||||
Point tensorboard to the directory for the run to see how the training
|
||||
progresses:
|
||||
|
||||
tensorboard --logdir=/tmp/boston
|
||||
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.boosted_trees.estimator_batch.dnn_tree_combined_estimator import DNNBoostedTreeCombinedRegressor
|
||||
from tensorflow.contrib.boosted_trees.proto import learner_pb2
|
||||
from tensorflow.contrib.layers.python.layers import feature_column
|
||||
from tensorflow.contrib.learn.python.learn import learn_runner
|
||||
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
|
||||
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
|
||||
|
||||
_BOSTON_NUM_FEATURES = 13
|
||||
|
||||
|
||||
def _get_estimator(output_dir, feature_cols):
|
||||
"""Configures DNNBoostedTreeCombinedRegressor based on flags."""
|
||||
learner_config = learner_pb2.LearnerConfig()
|
||||
learner_config.learning_rate_tuner.fixed.learning_rate = (
|
||||
FLAGS.tree_learning_rate)
|
||||
learner_config.regularization.l1 = 0.0
|
||||
learner_config.regularization.l2 = FLAGS.tree_l2
|
||||
learner_config.constraints.max_tree_depth = FLAGS.tree_depth
|
||||
|
||||
run_config = tf.contrib.learn.RunConfig(save_summary_steps=1)
|
||||
|
||||
# Create a DNNBoostedTreeCombinedRegressor estimator.
|
||||
estimator = DNNBoostedTreeCombinedRegressor(
|
||||
dnn_hidden_units=[int(x) for x in FLAGS.dnn_hidden_units.split(",")],
|
||||
dnn_feature_columns=feature_cols,
|
||||
tree_learner_config=learner_config,
|
||||
num_trees=FLAGS.num_trees,
|
||||
# This should be the number of examples. For large datasets it can be
|
||||
# larger than the batch_size.
|
||||
tree_examples_per_layer=FLAGS.batch_size,
|
||||
model_dir=output_dir,
|
||||
config=run_config,
|
||||
dnn_input_layer_to_tree=True,
|
||||
dnn_steps_to_train=FLAGS.dnn_steps_to_train)
|
||||
return estimator
|
||||
|
||||
|
||||
def _make_experiment_fn(output_dir):
|
||||
"""Creates experiment for DNNBoostedTreeCombinedRegressor."""
|
||||
(x_train, y_train), (x_test,
|
||||
y_test) = tf.keras.datasets.boston_housing.load_data()
|
||||
|
||||
train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
|
||||
x={"x": x_train},
|
||||
y=y_train,
|
||||
batch_size=FLAGS.batch_size,
|
||||
num_epochs=None,
|
||||
shuffle=True)
|
||||
eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
|
||||
x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False)
|
||||
|
||||
feature_columns = [
|
||||
feature_column.real_valued_column("x", dimension=_BOSTON_NUM_FEATURES)
|
||||
]
|
||||
feature_spec = tf.contrib.layers.create_feature_spec_for_parsing(
|
||||
feature_columns)
|
||||
serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec)
|
||||
export_strategies = [
|
||||
saved_model_export_utils.make_export_strategy(serving_input_fn)]
|
||||
return tf.contrib.learn.Experiment(
|
||||
estimator=_get_estimator(output_dir, feature_columns),
|
||||
train_input_fn=train_input_fn,
|
||||
eval_input_fn=eval_input_fn,
|
||||
train_steps=None,
|
||||
eval_steps=FLAGS.num_eval_steps,
|
||||
eval_metrics=None,
|
||||
export_strategies=export_strategies)
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
learn_runner.run(
|
||||
experiment_fn=_make_experiment_fn,
|
||||
output_dir=FLAGS.output_dir,
|
||||
schedule="train_and_evaluate")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
parser = argparse.ArgumentParser()
|
||||
# Define the list of flags that users can change.
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="The batch size for reading data.")
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Choose the dir for the output.")
|
||||
parser.add_argument(
|
||||
"--num_eval_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of steps to run evaluation for.")
|
||||
# Flags for configuring DNNBoostedTreeCombinedRegressor.
|
||||
parser.add_argument(
|
||||
"--dnn_hidden_units",
|
||||
type=str,
|
||||
default="8,4",
|
||||
help="Hidden layers for DNN.")
|
||||
parser.add_argument(
|
||||
"--dnn_steps_to_train",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of steps to train DNN.")
|
||||
parser.add_argument(
|
||||
"--tree_depth", type=int, default=4, help="Maximum depth of trees.")
|
||||
parser.add_argument(
|
||||
"--tree_l2", type=float, default=1.0, help="l2 regularization per batch.")
|
||||
parser.add_argument(
|
||||
"--tree_learning_rate",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help=("Learning rate (shrinkage weight) with which each "
|
||||
"new tree is added."))
|
||||
parser.add_argument(
|
||||
"--num_trees",
|
||||
type=int,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Number of trees to grow before stopping.")
|
||||
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
@ -1,171 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
r"""Demonstrates multiclass MNIST TF Boosted trees example.
|
||||
|
||||
This example demonstrates how to run experiments with TF Boosted Trees on
|
||||
a MNIST dataset. We are using layer by layer boosting with diagonal hessian
|
||||
strategy for multiclass handling, and cross entropy loss.
|
||||
|
||||
Example Usage:
|
||||
python tensorflow/contrib/boosted_trees/examples/mnist.py \
|
||||
--output_dir="/tmp/mnist" --depth=4 --learning_rate=0.3 --batch_size=60000 \
|
||||
--examples_per_layer=60000 --eval_batch_size=10000 --num_eval_steps=1 \
|
||||
--num_trees=10 --l2=1 --vmodule=training_ops=1
|
||||
|
||||
When training is done, accuracy on eval data is reported. Point tensorboard
|
||||
to the directory for the run to see how the training progresses:
|
||||
|
||||
tensorboard --logdir=/tmp/mnist
|
||||
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier
|
||||
from tensorflow.contrib.boosted_trees.proto import learner_pb2
|
||||
from tensorflow.contrib.learn import learn_runner
|
||||
|
||||
|
||||
def get_input_fn(dataset_split,
|
||||
batch_size,
|
||||
capacity=10000,
|
||||
min_after_dequeue=3000):
|
||||
"""Input function over MNIST data."""
|
||||
|
||||
def _input_fn():
|
||||
"""Prepare features and labels."""
|
||||
images_batch, labels_batch = tf.train.shuffle_batch(
|
||||
tensors=[dataset_split.images,
|
||||
dataset_split.labels.astype(np.int32)],
|
||||
batch_size=batch_size,
|
||||
capacity=capacity,
|
||||
min_after_dequeue=min_after_dequeue,
|
||||
enqueue_many=True,
|
||||
num_threads=4)
|
||||
features_map = {"images": images_batch}
|
||||
return features_map, labels_batch
|
||||
|
||||
return _input_fn
|
||||
|
||||
|
||||
# Main config - creates a TF Boosted Trees Estimator based on flags.
|
||||
def _get_tfbt(output_dir):
|
||||
"""Configures TF Boosted Trees estimator based on flags."""
|
||||
learner_config = learner_pb2.LearnerConfig()
|
||||
|
||||
num_classes = 10
|
||||
|
||||
learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate
|
||||
learner_config.num_classes = num_classes
|
||||
learner_config.regularization.l1 = 0.0
|
||||
learner_config.regularization.l2 = FLAGS.l2 / FLAGS.examples_per_layer
|
||||
learner_config.constraints.max_tree_depth = FLAGS.depth
|
||||
|
||||
growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER
|
||||
learner_config.growing_mode = growing_mode
|
||||
run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300)
|
||||
|
||||
learner_config.multi_class_strategy = (
|
||||
learner_pb2.LearnerConfig.DIAGONAL_HESSIAN)
|
||||
|
||||
# Create a TF Boosted trees estimator that can take in custom loss.
|
||||
estimator = GradientBoostedDecisionTreeClassifier(
|
||||
learner_config=learner_config,
|
||||
n_classes=num_classes,
|
||||
examples_per_layer=FLAGS.examples_per_layer,
|
||||
model_dir=output_dir,
|
||||
num_trees=FLAGS.num_trees,
|
||||
center_bias=False,
|
||||
config=run_config)
|
||||
return estimator
|
||||
|
||||
|
||||
def _make_experiment_fn(output_dir):
|
||||
"""Creates experiment for gradient boosted decision trees."""
|
||||
data = tf.contrib.learn.datasets.mnist.load_mnist()
|
||||
train_input_fn = get_input_fn(data.train, FLAGS.batch_size)
|
||||
eval_input_fn = get_input_fn(data.validation, FLAGS.eval_batch_size)
|
||||
|
||||
return tf.contrib.learn.Experiment(
|
||||
estimator=_get_tfbt(output_dir),
|
||||
train_input_fn=train_input_fn,
|
||||
eval_input_fn=eval_input_fn,
|
||||
train_steps=None,
|
||||
eval_steps=FLAGS.num_eval_steps,
|
||||
eval_metrics=None)
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
learn_runner.run(
|
||||
experiment_fn=_make_experiment_fn,
|
||||
output_dir=FLAGS.output_dir,
|
||||
schedule="train_and_evaluate")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
parser = argparse.ArgumentParser()
|
||||
# Define the list of flags that users can change.
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Choose the dir for the output.")
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="The batch size for reading data.")
|
||||
parser.add_argument(
|
||||
"--eval_batch_size",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Size of the batch for eval.")
|
||||
parser.add_argument(
|
||||
"--num_eval_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of steps to run evaluation for.")
|
||||
# Flags for gradient boosted trees config.
|
||||
parser.add_argument(
|
||||
"--depth", type=int, default=4, help="Maximum depth of weak learners.")
|
||||
parser.add_argument(
|
||||
"--l2", type=float, default=1.0, help="l2 regularization per batch.")
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="Learning rate (shrinkage weight) with which each new tree is added."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--examples_per_layer",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of examples to accumulate stats for per layer.")
|
||||
parser.add_argument(
|
||||
"--num_trees",
|
||||
type=int,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Number of trees to grow before stopping.")
|
||||
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
@ -1,212 +0,0 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
|
||||
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace boosted_trees {
|
||||
|
||||
using boosted_trees::models::DecisionTreeEnsembleResource;
|
||||
|
||||
// Creates a tree ensemble variable.
|
||||
class CreateTreeEnsembleVariableOp : public OpKernel {
|
||||
public:
|
||||
explicit CreateTreeEnsembleVariableOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
// Get the stamp token.
|
||||
const Tensor* stamp_token_t;
|
||||
OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
|
||||
int64 stamp_token = stamp_token_t->scalar<int64>()();
|
||||
|
||||
// Get the tree ensemble config.
|
||||
const Tensor* tree_ensemble_config_t;
|
||||
OP_REQUIRES_OK(context, context->input("tree_ensemble_config",
|
||||
&tree_ensemble_config_t));
|
||||
auto* result = new DecisionTreeEnsembleResource();
|
||||
if (!result->InitFromSerialized(tree_ensemble_config_t->scalar<tstring>()(),
|
||||
stamp_token)) {
|
||||
result->Unref();
|
||||
OP_REQUIRES(
|
||||
context, false,
|
||||
errors::InvalidArgument("Unable to parse tree ensemble config."));
|
||||
}
|
||||
|
||||
// Only create one, if one does not exist already. Report status for all
|
||||
// other exceptions.
|
||||
auto status = CreateResource(context, HandleFromInput(context, 0), result);
|
||||
if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
|
||||
OP_REQUIRES(context, false, status);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Op for retrieving a model stamp token without having to serialize.
|
||||
class TreeEnsembleStampTokenOp : public OpKernel {
|
||||
public:
|
||||
explicit TreeEnsembleStampTokenOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&ensemble_resource));
|
||||
tf_shared_lock l(*ensemble_resource->get_mutex());
|
||||
Tensor* output_stamp_token_t = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
|
||||
&output_stamp_token_t));
|
||||
output_stamp_token_t->scalar<int64>()() = ensemble_resource->stamp();
|
||||
}
|
||||
};
|
||||
|
||||
// Op for serializing a model.
|
||||
class TreeEnsembleSerializeOp : public OpKernel {
|
||||
public:
|
||||
explicit TreeEnsembleSerializeOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&ensemble_resource));
|
||||
tf_shared_lock l(*ensemble_resource->get_mutex());
|
||||
Tensor* output_stamp_token_t = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
|
||||
&output_stamp_token_t));
|
||||
output_stamp_token_t->scalar<int64>()() = ensemble_resource->stamp();
|
||||
Tensor* output_config_t = nullptr;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->allocate_output(1, TensorShape(), &output_config_t));
|
||||
output_config_t->scalar<tstring>()() =
|
||||
ensemble_resource->SerializeAsString();
|
||||
}
|
||||
};
|
||||
|
||||
// Op for deserializing a tree ensemble variable from a checkpoint.
|
||||
class TreeEnsembleDeserializeOp : public OpKernel {
|
||||
public:
|
||||
explicit TreeEnsembleDeserializeOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&ensemble_resource));
|
||||
mutex_lock l(*ensemble_resource->get_mutex());
|
||||
|
||||
// Get the stamp token.
|
||||
const Tensor* stamp_token_t;
|
||||
OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
|
||||
int64 stamp_token = stamp_token_t->scalar<int64>()();
|
||||
|
||||
// Get the tree ensemble config.
|
||||
const Tensor* tree_ensemble_config_t;
|
||||
OP_REQUIRES_OK(context, context->input("tree_ensemble_config",
|
||||
&tree_ensemble_config_t));
|
||||
// Deallocate all the previous objects on the resource.
|
||||
ensemble_resource->Reset();
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
ensemble_resource->InitFromSerialized(
|
||||
tree_ensemble_config_t->scalar<tstring>()(), stamp_token),
|
||||
errors::InvalidArgument("Unable to parse tree ensemble config."));
|
||||
}
|
||||
};
|
||||
|
||||
class TreeEnsembleUsedHandlersOp : public OpKernel {
|
||||
public:
|
||||
explicit TreeEnsembleUsedHandlersOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("num_all_handlers", &num_handlers_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
|
||||
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&ensemble_resource));
|
||||
tf_shared_lock l(*ensemble_resource->get_mutex());
|
||||
|
||||
// Get the stamp token.
|
||||
const Tensor* stamp_token_t;
|
||||
OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
|
||||
int64 stamp_token = stamp_token_t->scalar<int64>()();
|
||||
|
||||
// Only the Chief should run this Op and it is guaranteed to be in
|
||||
// a consistent state so the stamps must always match.
|
||||
CHECK(ensemble_resource->is_stamp_valid(stamp_token));
|
||||
|
||||
Tensor* output_used_handlers_t = nullptr;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->allocate_output("used_handlers_mask", {num_handlers_},
|
||||
&output_used_handlers_t));
|
||||
auto output_used_handlers = output_used_handlers_t->vec<bool>();
|
||||
|
||||
Tensor* output_num_used_handlers_t = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output("num_used_handlers", {},
|
||||
&output_num_used_handlers_t));
|
||||
int handler_idx = 0;
|
||||
std::vector<int64> used_handlers = ensemble_resource->GetUsedHandlers();
|
||||
output_num_used_handlers_t->scalar<int64>()() = used_handlers.size();
|
||||
for (int64 i = 0; i < num_handlers_; ++i) {
|
||||
if (handler_idx >= used_handlers.size() ||
|
||||
used_handlers[handler_idx] > i) {
|
||||
output_used_handlers(i) = false;
|
||||
} else {
|
||||
OP_REQUIRES(context, used_handlers[handler_idx] == i,
|
||||
errors::InvalidArgument("Handler IDs should be sorted."));
|
||||
++handler_idx;
|
||||
output_used_handlers(i) = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int64 num_handlers_;
|
||||
};
|
||||
|
||||
REGISTER_RESOURCE_HANDLE_KERNEL(DecisionTreeEnsembleResource);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TreeEnsembleIsInitializedOp").Device(DEVICE_CPU),
|
||||
IsResourceInitialized<DecisionTreeEnsembleResource>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("CreateTreeEnsembleVariable").Device(DEVICE_CPU),
|
||||
CreateTreeEnsembleVariableOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TreeEnsembleStampToken").Device(DEVICE_CPU),
|
||||
TreeEnsembleStampTokenOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TreeEnsembleSerialize").Device(DEVICE_CPU),
|
||||
TreeEnsembleSerializeOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TreeEnsembleDeserialize").Device(DEVICE_CPU),
|
||||
TreeEnsembleDeserializeOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TreeEnsembleUsedHandlers").Device(DEVICE_CPU),
|
||||
TreeEnsembleUsedHandlersOp);
|
||||
} // namespace boosted_trees
|
||||
} // namespace tensorflow
|
@ -1,468 +0,0 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h"
|
||||
#include "tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h"
|
||||
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
|
||||
#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"
|
||||
#include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
|
||||
#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
|
||||
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
|
||||
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
using tensorflow::boosted_trees::learner::AveragingConfig;
|
||||
using tensorflow::boosted_trees::trees::DecisionTreeEnsembleConfig;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace boosted_trees {
|
||||
|
||||
using boosted_trees::learner::LearnerConfig;
|
||||
using boosted_trees::learner::LearningRateConfig;
|
||||
using boosted_trees::learner::LearningRateDropoutDrivenConfig;
|
||||
using boosted_trees::models::DecisionTreeEnsembleResource;
|
||||
using boosted_trees::models::MultipleAdditiveTrees;
|
||||
using boosted_trees::utils::DropoutUtils;
|
||||
using boosted_trees::utils::TensorUtils;
|
||||
|
||||
namespace {
|
||||
const char* kLearnerConfigAttributeName = "learner_config";
|
||||
const char* kSeedTensorName = "seed";
|
||||
const char* kApplyDropoutAttributeName = "apply_dropout";
|
||||
const char* kApplyAveragingAttributeName = "apply_averaging";
|
||||
const char* kDropoutInfoOutputTensorName = "drop_out_tree_indices_weights";
|
||||
const char* kPredictionsTensorName = "predictions";
|
||||
const char* kLeafIndexTensorName = "leaf_index";
|
||||
|
||||
void CalculateTreesToInclude(
|
||||
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
|
||||
const std::vector<int32>& trees_to_drop, const int32 num_trees,
|
||||
const bool only_finalized, const bool center_bias,
|
||||
std::vector<int32>* trees_to_include) {
|
||||
trees_to_include->reserve(num_trees - trees_to_drop.size());
|
||||
|
||||
int32 index = 0;
|
||||
// This assumes that trees_to_drop is a sorted list of tree ids.
|
||||
for (int32 tree = 0; tree < num_trees; ++tree) {
|
||||
// Skip the tree if tree is in the list of trees_to_drop.
|
||||
if (!trees_to_drop.empty() && index < trees_to_drop.size() &&
|
||||
trees_to_drop[index] == tree) {
|
||||
++index;
|
||||
continue;
|
||||
}
|
||||
// Or skip if the tree is not finalized and only_finalized is set,
|
||||
// with the exception of centering bias.
|
||||
if (only_finalized && !(center_bias && tree == 0) &&
|
||||
config.tree_metadata_size() > 0 &&
|
||||
!config.tree_metadata(tree).is_finalized()) {
|
||||
continue;
|
||||
}
|
||||
trees_to_include->push_back(tree);
|
||||
}
|
||||
}
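// Illustrative walk-through (not part of the original source): with
// num_trees = 4, trees_to_drop = {1}, only_finalized = true, center_bias =
// true, and tree 3 not yet finalized, the loop keeps tree 0 (the bias tree is
// always kept), skips tree 1 (dropped), keeps tree 2 (finalized), and skips
// tree 3 (not finalized), so *trees_to_include ends up as {0, 2}.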
|
||||
} // namespace
|
||||
|
||||
class GradientTreesPredictionOp : public OpKernel {
|
||||
public:
|
||||
explicit GradientTreesPredictionOp(OpKernelConstruction* const context)
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("use_locking", &use_locking_));
|
||||
|
||||
OP_REQUIRES_OK(context, context->GetAttr("center_bias", ¢er_bias_));
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
context, context->GetAttr(kApplyDropoutAttributeName, &apply_dropout_));
|
||||
|
||||
LearnerConfig learner_config;
|
||||
string learner_config_str;
|
||||
OP_REQUIRES_OK(context, context->GetAttr(kLearnerConfigAttributeName,
|
||||
&learner_config_str));
|
||||
    OP_REQUIRES(
        context, ParseProtoUnlimited(&learner_config, learner_config_str),
        errors::InvalidArgument("Unable to parse learner config."));

    num_classes_ = learner_config.num_classes();
    OP_REQUIRES(context, num_classes_ >= 2,
                errors::InvalidArgument("Number of classes must be >=2"));
|
||||
|
||||
bool reduce_dim;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("reduce_dim", &reduce_dim));
|
||||
prediction_vector_size_ = reduce_dim ? num_classes_ - 1 : num_classes_;
|
||||
|
||||
only_finalized_trees_ =
|
||||
learner_config.growing_mode() == learner_config.WHOLE_TREE;
|
||||
if (learner_config.has_learning_rate_tuner() &&
|
||||
learner_config.learning_rate_tuner().tuner_case() ==
|
||||
LearningRateConfig::kDropout) {
|
||||
dropout_config_ = learner_config.learning_rate_tuner().dropout();
|
||||
has_dropout_ = true;
|
||||
} else {
|
||||
has_dropout_ = false;
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(context, context->GetAttr(kApplyAveragingAttributeName,
|
||||
&apply_averaging_));
|
||||
apply_averaging_ =
|
||||
apply_averaging_ && learner_config.averaging_config().config_case() !=
|
||||
AveragingConfig::CONFIG_NOT_SET;
|
||||
if (apply_averaging_) {
|
||||
averaging_config_ = learner_config.averaging_config();
|
||||
|
||||
// If there is averaging config, check that the values are correct.
|
||||
switch (averaging_config_.config_case()) {
|
||||
case AveragingConfig::kAverageLastNTreesFieldNumber: {
|
||||
OP_REQUIRES(context, averaging_config_.average_last_n_trees() > 0,
|
||||
errors::InvalidArgument(
|
||||
"Average last n trees must be a positive number"));
|
||||
break;
|
||||
}
|
||||
case AveragingConfig::kAverageLastPercentTreesFieldNumber: {
|
||||
OP_REQUIRES(context,
|
||||
averaging_config_.average_last_percent_trees() > 0 &&
|
||||
averaging_config_.average_last_percent_trees() <= 1.0,
|
||||
errors::InvalidArgument(
|
||||
"Average last percent must be in (0,1] interval."));
|
||||
break;
|
||||
}
|
||||
case AveragingConfig::CONFIG_NOT_SET: {
|
||||
LOG(QFATAL) << "We should never get here.";
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* const context) override {
|
||||
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
|
||||
// Gets the resource. Grabs the mutex but releases it.
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&ensemble_resource));
|
||||
// Release the reference to the resource once we're done using it.
|
||||
if (use_locking_) {
|
||||
tf_shared_lock l(*ensemble_resource->get_mutex());
|
||||
DoCompute(context, ensemble_resource,
|
||||
/*return_output_leaf_index=*/false);
|
||||
} else {
|
||||
DoCompute(context, ensemble_resource,
|
||||
/*return_output_leaf_index=*/false);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
  // return_output_leaf_index indicates whether to output the leaf index as
  // part of the prediction. This class always invokes DoCompute with false;
  // the subclass GradientTreesPredictionVerboseOp invokes it with true.
|
||||
virtual void DoCompute(
|
||||
OpKernelContext* context,
|
||||
const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
|
||||
const bool return_output_leaf_index) {
|
||||
// Read dense float features list;
|
||||
OpInputList dense_float_features_list;
|
||||
OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
|
||||
context, &dense_float_features_list));
|
||||
|
||||
// Read sparse float features list;
|
||||
OpInputList sparse_float_feature_indices_list;
|
||||
OpInputList sparse_float_feature_values_list;
|
||||
OpInputList sparse_float_feature_shapes_list;
|
||||
OP_REQUIRES_OK(context, TensorUtils::ReadSparseFloatFeatures(
|
||||
context, &sparse_float_feature_indices_list,
|
||||
&sparse_float_feature_values_list,
|
||||
&sparse_float_feature_shapes_list));
|
||||
|
||||
// Read sparse int features list;
|
||||
OpInputList sparse_int_feature_indices_list;
|
||||
OpInputList sparse_int_feature_values_list;
|
||||
OpInputList sparse_int_feature_shapes_list;
|
||||
OP_REQUIRES_OK(context, TensorUtils::ReadSparseIntFeatures(
|
||||
context, &sparse_int_feature_indices_list,
|
||||
&sparse_int_feature_values_list,
|
||||
&sparse_int_feature_shapes_list));
|
||||
|
||||
// Infer batch size.
|
||||
const int64 batch_size = TensorUtils::InferBatchSize(
|
||||
dense_float_features_list, sparse_float_feature_shapes_list,
|
||||
sparse_int_feature_shapes_list);
|
||||
|
||||
// Read batch features.
|
||||
boosted_trees::utils::BatchFeatures batch_features(batch_size);
|
||||
OP_REQUIRES_OK(
|
||||
context,
|
||||
batch_features.Initialize(
|
||||
TensorUtils::OpInputListToTensorVec(dense_float_features_list),
|
||||
TensorUtils::OpInputListToTensorVec(
|
||||
sparse_float_feature_indices_list),
|
||||
TensorUtils::OpInputListToTensorVec(
|
||||
sparse_float_feature_values_list),
|
||||
TensorUtils::OpInputListToTensorVec(
|
||||
sparse_float_feature_shapes_list),
|
||||
TensorUtils::OpInputListToTensorVec(
|
||||
sparse_int_feature_indices_list),
|
||||
TensorUtils::OpInputListToTensorVec(sparse_int_feature_values_list),
|
||||
TensorUtils::OpInputListToTensorVec(
|
||||
sparse_int_feature_shapes_list)));
|
||||
|
||||
std::vector<int32> dropped_trees;
|
||||
std::vector<float> original_weights;
|
||||
|
||||
// Do dropout if needed.
|
||||
if (apply_dropout_ && has_dropout_) {
|
||||
// Read in seed and cast to uint64.
|
||||
const Tensor* seed_t;
|
||||
OP_REQUIRES_OK(context, context->input(kSeedTensorName, &seed_t));
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_t->shape()),
|
||||
errors::InvalidArgument("Seed must be a scalar."));
|
||||
const uint64 seed = seed_t->scalar<int64>()();
|
||||
|
||||
std::unordered_set<int32> trees_not_to_drop;
|
||||
if (center_bias_) {
|
||||
trees_not_to_drop.insert(0);
|
||||
}
|
||||
if (ensemble_resource->decision_tree_ensemble().has_growing_metadata()) {
|
||||
        // We are in batch mode; the last tree is the one currently being
        // built and cannot be dropped during dropout.
|
||||
trees_not_to_drop.insert(ensemble_resource->num_trees() - 1);
|
||||
}
|
||||
const std::vector<float> weights = ensemble_resource->GetTreeWeights();
|
||||
OP_REQUIRES_OK(context, DropoutUtils::DropOutTrees(
|
||||
seed, dropout_config_, trees_not_to_drop,
|
||||
weights, &dropped_trees, &original_weights));
|
||||
}
|
||||
|
||||
// Prepare the list of trees to include in the prediction.
|
||||
std::vector<int32> trees_to_include;
|
||||
CalculateTreesToInclude(
|
||||
ensemble_resource->decision_tree_ensemble(), dropped_trees,
|
||||
ensemble_resource->decision_tree_ensemble().trees_size(),
|
||||
only_finalized_trees_, center_bias_, &trees_to_include);
|
||||
|
||||
// Allocate output predictions matrix.
|
||||
Tensor* output_predictions_t = nullptr;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->allocate_output(kPredictionsTensorName,
|
||||
{batch_size, prediction_vector_size_},
|
||||
&output_predictions_t));
|
||||
auto output_predictions = output_predictions_t->matrix<float>();
|
||||
|
||||
// Allocate output leaf index matrix.
|
||||
Tensor* output_leaf_index_t = nullptr;
|
||||
if (return_output_leaf_index) {
|
||||
OP_REQUIRES_OK(context, context->allocate_output(
|
||||
kLeafIndexTensorName,
|
||||
{batch_size, ensemble_resource->num_trees()},
|
||||
&output_leaf_index_t));
|
||||
}
|
||||
// Run predictor.
|
||||
thread::ThreadPool* const worker_threads =
|
||||
context->device()->tensorflow_cpu_worker_threads()->workers;
|
||||
|
||||
if (apply_averaging_) {
|
||||
DecisionTreeEnsembleConfig adjusted =
|
||||
ensemble_resource->decision_tree_ensemble();
|
||||
const int start_averaging = std::max(
|
||||
0.0,
|
||||
averaging_config_.config_case() ==
|
||||
AveragingConfig::kAverageLastNTreesFieldNumber
|
||||
? adjusted.trees_size() - averaging_config_.average_last_n_trees()
|
||||
: adjusted.trees_size() *
|
||||
(1.0 - averaging_config_.average_last_percent_trees()));
|
||||
const int num_ensembles = adjusted.trees_size() - start_averaging;
|
||||
for (int i = start_averaging; i < adjusted.trees_size(); ++i) {
|
||||
float weight = adjusted.tree_weights(i);
|
||||
adjusted.mutable_tree_weights()->Set(
|
||||
i, weight * (num_ensembles - i + start_averaging) / num_ensembles);
|
||||
}
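      // Worked example (illustrative, not in the original source): with 10
      // trees and average_last_n_trees = 4, start_averaging = 6 and
      // num_ensembles = 4, so trees 6..9 get their weights scaled by
      // 4/4, 3/4, 2/4 and 1/4 respectively -- the newest tree contributes
      // the least to the averaged prediction.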
|
||||
MultipleAdditiveTrees::Predict(adjusted, trees_to_include, batch_features,
|
||||
worker_threads, output_predictions,
|
||||
output_leaf_index_t);
|
||||
} else {
|
||||
MultipleAdditiveTrees::Predict(
|
||||
ensemble_resource->decision_tree_ensemble(), trees_to_include,
|
||||
batch_features, worker_threads, output_predictions,
|
||||
output_leaf_index_t);
|
||||
}
|
||||
|
||||
// Output dropped trees and original weights.
|
||||
Tensor* output_dropout_info_t = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(
|
||||
kDropoutInfoOutputTensorName,
|
||||
{2, static_cast<int64>(dropped_trees.size())},
|
||||
&output_dropout_info_t));
|
||||
auto output_dropout_info = output_dropout_info_t->matrix<float>();
|
||||
for (int32 i = 0; i < dropped_trees.size(); ++i) {
|
||||
output_dropout_info(0, i) = dropped_trees[i];
|
||||
output_dropout_info(1, i) = original_weights[i];
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
LearningRateDropoutDrivenConfig dropout_config_;
|
||||
AveragingConfig averaging_config_;
|
||||
bool only_finalized_trees_;
|
||||
int num_classes_;
|
||||
  // The size of the output vector for predictions.
|
||||
int prediction_vector_size_;
|
||||
bool apply_dropout_;
|
||||
bool center_bias_;
|
||||
bool apply_averaging_;
|
||||
bool use_locking_;
|
||||
bool has_dropout_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("GradientTreesPrediction").Device(DEVICE_CPU),
|
||||
GradientTreesPredictionOp);
|
||||
|
||||
// GradientTreesPredictionVerboseOp is derived from GradientTreesPredictionOp
// and has an additional rank-2 output tensor containing, for each tree, the
// leaf id in which each instance ended up.
|
||||
class GradientTreesPredictionVerboseOp : public GradientTreesPredictionOp {
|
||||
public:
|
||||
explicit GradientTreesPredictionVerboseOp(OpKernelConstruction* const context)
|
||||
: GradientTreesPredictionOp(context) {}
|
||||
|
||||
protected:
|
||||
void DoCompute(
|
||||
OpKernelContext* context,
|
||||
const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
|
||||
bool return_output_leaf_index) override {
|
||||
GradientTreesPredictionOp::DoCompute(context, ensemble_resource,
|
||||
/*return_output_leaf_index=*/true);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("GradientTreesPredictionVerbose").Device(DEVICE_CPU),
|
||||
GradientTreesPredictionVerboseOp);
|
||||
|
||||
class GradientTreesPartitionExamplesOp : public OpKernel {
|
||||
public:
|
||||
explicit GradientTreesPartitionExamplesOp(OpKernelConstruction* const context)
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("use_locking", &use_locking_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* const context) override {
|
||||
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
|
||||
// Gets the resource. Grabs the mutex but releases it.
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&ensemble_resource));
|
||||
if (use_locking_) {
|
||||
tf_shared_lock l(*ensemble_resource->get_mutex());
|
||||
DoCompute(context, ensemble_resource);
|
||||
} else {
|
||||
DoCompute(context, ensemble_resource);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void DoCompute(
|
||||
OpKernelContext* context,
|
||||
const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
|
||||
// The last non-finalized tree in the ensemble is by convention the
|
||||
// one to partition on. If no such tree exists, a nodeless tree is
|
||||
// created.
|
||||
boosted_trees::trees::DecisionTreeConfig empty_tree_config;
|
||||
const boosted_trees::trees::DecisionTreeConfig& tree_config =
|
||||
(resource->num_trees() <= 0 ||
|
||||
resource->LastTreeMetadata()->is_finalized())
|
||||
? empty_tree_config
|
||||
: *resource->LastTree();
|
||||
|
||||
// Read dense float features list;
|
||||
OpInputList dense_float_features_list;
|
||||
OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
|
||||
context, &dense_float_features_list));
|
||||
|
||||
// Read sparse float features list;
|
||||
OpInputList sparse_float_feature_indices_list;
|
||||
OpInputList sparse_float_feature_values_list;
|
||||
OpInputList sparse_float_feature_shapes_list;
|
||||
OP_REQUIRES_OK(context, TensorUtils::ReadSparseFloatFeatures(
|
||||
context, &sparse_float_feature_indices_list,
|
||||
&sparse_float_feature_values_list,
|
||||
&sparse_float_feature_shapes_list));
|
||||
|
||||
// Read sparse int features list;
|
||||
OpInputList sparse_int_feature_indices_list;
|
||||
OpInputList sparse_int_feature_values_list;
|
||||
OpInputList sparse_int_feature_shapes_list;
|
||||
OP_REQUIRES_OK(context, TensorUtils::ReadSparseIntFeatures(
|
||||
context, &sparse_int_feature_indices_list,
|
||||
&sparse_int_feature_values_list,
|
||||
&sparse_int_feature_shapes_list));
|
||||
|
||||
// Infer batch size.
|
||||
const int64 batch_size = TensorUtils::InferBatchSize(
|
||||
dense_float_features_list, sparse_float_feature_shapes_list,
|
||||
sparse_int_feature_shapes_list);
|
||||
|
||||
// Read batch features.
|
||||
boosted_trees::utils::BatchFeatures batch_features(batch_size);
|
||||
OP_REQUIRES_OK(
|
||||
context,
|
||||
batch_features.Initialize(
|
||||
TensorUtils::OpInputListToTensorVec(dense_float_features_list),
|
||||
TensorUtils::OpInputListToTensorVec(
|
||||
sparse_float_feature_indices_list),
|
||||
TensorUtils::OpInputListToTensorVec(
|
||||
sparse_float_feature_values_list),
|
||||
TensorUtils::OpInputListToTensorVec(
|
||||
sparse_float_feature_shapes_list),
|
||||
TensorUtils::OpInputListToTensorVec(
|
||||
sparse_int_feature_indices_list),
|
||||
TensorUtils::OpInputListToTensorVec(sparse_int_feature_values_list),
|
||||
TensorUtils::OpInputListToTensorVec(
|
||||
sparse_int_feature_shapes_list)));
|
||||
|
||||
// Allocate output partitions vector.
|
||||
Tensor* partition_ids_t = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, {batch_size}, &partition_ids_t));
|
||||
thread::ThreadPool* const worker_threads =
|
||||
context->device()->tensorflow_cpu_worker_threads()->workers;
|
||||
learner::ExamplePartitioner::PartitionExamples(
|
||||
tree_config, batch_features, worker_threads->NumThreads(),
|
||||
worker_threads, partition_ids_t->vec<int32>().data());
|
||||
}
|
||||
|
||||
private:
|
||||
bool use_locking_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("GradientTreesPartitionExamples").Device(DEVICE_CPU),
|
||||
GradientTreesPartitionExamplesOp);
|
||||
|
||||
} // namespace boosted_trees
|
||||
} // namespace tensorflow
|
@ -1,985 +0,0 @@
|
||||
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h"
|
||||
#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
|
||||
#include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
|
||||
#include "tensorflow/contrib/boosted_trees/proto/quantiles.pb.h"
|
||||
#include "tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using ::boosted_trees::QuantileConfig;
|
||||
using boosted_trees::QuantileStreamResource;
|
||||
using boosted_trees::utils::TensorUtils;
|
||||
|
||||
namespace {
|
||||
const char* const kExampleWeightsName = "example_weights";
|
||||
const char* const kMaxElementsName = "max_elements";
|
||||
const char* const kNextStampTokenName = "next_stamp_token";
|
||||
const char* const kStampTokenName = "stamp_token";
|
||||
const char* const kAreBucketsReadyName = "are_buckets_ready";
|
||||
const char* const kGenerateQuantiles = "generate_quantiles";
|
||||
// Names for sparse arguments.
|
||||
const char* const kNumSparseFeaturesName = "num_sparse_features";
|
||||
const char* const kSparseBucketsName = "sparse_buckets";
|
||||
const char* const kSparseValuesName = "sparse_values";
|
||||
const char* const kSparseIndicesName = "sparse_indices";
|
||||
const char* const kSparseSummariesName = "sparse_summaries";
|
||||
const char* const kSparseConfigName = "sparse_config";
|
||||
const char* const kSparseOutputTensorName = "sparse_quantiles";
|
||||
// Names for dense arguments.
|
||||
const char* const kDenseBucketsName = "dense_buckets";
|
||||
const char* const kDenseConfigName = "dense_config";
|
||||
const char* const kDenseOutputTensorName = "dense_quantiles";
|
||||
const char* const kDenseSummariesName = "dense_summaries";
|
||||
const char* const kDenseValuesName = "dense_values";
|
||||
const char* const kNumDenseFeaturesName = "num_dense_features";
|
||||
const char* const kResourceHandlesName = "quantile_accumulator_handles";
|
||||
const char* const kNumQuantilesName = "num_quantiles";
|
||||
const char* const kEpsilonName = "epsilon";
|
||||
const char* const kBucketsName = "buckets";
|
||||
const char* const kStreamStateName = "stream_state";
|
||||
const char* const kSummariesName = "summaries";
|
||||
|
||||
using QuantileStream =
|
||||
boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
|
||||
using QuantileSummary =
|
||||
boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
|
||||
using QuantileSummaryEntry =
|
||||
boosted_trees::quantiles::WeightedQuantilesSummary<float,
|
||||
float>::SummaryEntry;
|
||||
|
||||
std::vector<float> GetBuckets(const int32 feature,
|
||||
const OpInputList& buckets_list) {
|
||||
const auto& buckets = buckets_list[feature].flat<float>();
|
||||
std::vector<float> buckets_vector(buckets.data(),
|
||||
buckets.data() + buckets.size());
|
||||
return buckets_vector;
|
||||
}
|
||||
|
||||
int32 GetFeatureDimension(const int32 feature_index, const int64 instance,
|
||||
const OpInputList* const indices_list) {
|
||||
if (indices_list != nullptr) {
|
||||
// Sparse multidimensional.
|
||||
return (*indices_list)[feature_index].matrix<int64>()(instance, 1);
|
||||
}
|
||||
// No indices, assume one-dimensional tensor.
|
||||
return 0;
|
||||
}
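// Illustrative example (not in the original source): for a sparse feature
// whose indices matrix is {{3, 0}, {3, 2}}, instance row 0 lies in dimension
// 0 and row 1 in dimension 2; dense (one-dimensional) features always map to
// dimension 0.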
|
||||
|
||||
// Allows quantization for each of multiple dimensions of a sparse feature.
|
||||
void QuantizeFeatures(
|
||||
const string& output_name, const OpInputList& values_list,
|
||||
const OpInputList& buckets_list,
|
||||
const OpInputList* const
|
||||
indices_list /** Optional, provide for sparse features **/,
|
||||
OpKernelContext* const context) {
|
||||
if (values_list.size() == 0) {
|
||||
return;
|
||||
}
|
||||
OpOutputList output_list;
|
||||
OP_REQUIRES_OK(context, context->output_list(output_name, &output_list));
|
||||
|
||||
for (int32 feature_index = 0; feature_index < values_list.size();
|
||||
++feature_index) {
|
||||
const Tensor& values_tensor = values_list[feature_index];
|
||||
const int64 num_values = values_tensor.dim_size(0);
|
||||
|
||||
Tensor* output_t = nullptr;
|
||||
    // Each output row holds the bucket id and the feature dimension for
    // that value.
|
||||
OP_REQUIRES_OK(
|
||||
context, output_list.allocate(feature_index,
|
||||
TensorShape({num_values, 2}), &output_t));
|
||||
|
||||
auto output = output_t->matrix<int32>();
|
||||
|
||||
const std::vector<float>& buckets_vector =
|
||||
GetBuckets(feature_index, buckets_list);
|
||||
auto flat_values = values_tensor.flat<float>();
|
||||
for (int64 instance = 0; instance < num_values; ++instance) {
|
||||
const float value = flat_values(instance);
|
||||
CHECK(!buckets_vector.empty())
|
||||
<< "Got empty buckets for feature " << feature_index;
|
||||
auto bucket_iter =
|
||||
std::lower_bound(buckets_vector.begin(), buckets_vector.end(), value);
|
||||
if (bucket_iter == buckets_vector.end()) {
|
||||
--bucket_iter;
|
||||
}
|
||||
const int32 bucket =
|
||||
static_cast<int32>(bucket_iter - buckets_vector.begin());
|
||||
// Bucket id.
|
||||
output(instance, 0) = bucket;
|
||||
// Dimension.
|
||||
output(instance, 1) =
|
||||
GetFeatureDimension(feature_index, instance, indices_list);
|
||||
}
|
||||
}
|
||||
}
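// Illustrative example (not in the original source): with buckets
// {0.5, 1.5, 2.5}, std::lower_bound maps a value of 0.4 to bucket 0, 1.0 to
// bucket 1 (the first boundary >= 1.0 is 1.5), and 3.0 -- which falls past
// the last boundary -- is clamped back to bucket 2.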
|
||||
|
||||
// Validates attributes for the quantile ops.
|
||||
Status ReadAndValidateAttributes(OpKernelConstruction* const context,
|
||||
int* num_dense_features,
|
||||
int* num_sparse_features) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
context->GetAttr(kNumDenseFeaturesName, num_dense_features));
|
||||
TF_RETURN_IF_ERROR(
|
||||
context->GetAttr(kNumSparseFeaturesName, num_sparse_features));
|
||||
if ((*num_dense_features) + (*num_sparse_features) == 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Please provide at least sparse or dense features.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void ParseConfig(OpKernelConstruction* const context, const string& name,
|
||||
std::vector<QuantileConfig>* output) {
|
||||
std::vector<string> serialized_config;
|
||||
OP_REQUIRES_OK(context, context->GetAttr(name, &serialized_config));
|
||||
output->reserve(serialized_config.size());
|
||||
QuantileConfig tmp;
|
||||
for (const auto& serialized_string : serialized_config) {
|
||||
OP_REQUIRES(context, tmp.ParseFromString(serialized_string),
|
||||
errors::InvalidArgument("Malformed QuantileConfig passed in."));
|
||||
output->push_back(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
// Generates quantiles on a finalized QuantileStream.
|
||||
std::vector<float> GenerateBoundaries(const QuantileStream& stream,
|
||||
int num_boundaries) {
|
||||
std::vector<float> boundaries = stream.GenerateBoundaries(num_boundaries);
|
||||
|
||||
// Uniquify elements as we may get dupes.
|
||||
auto end_it = std::unique(boundaries.begin(), boundaries.end());
|
||||
boundaries.resize(std::distance(boundaries.begin(), end_it));
|
||||
return boundaries;
|
||||
}
|
||||
|
||||
// Generates quantiles on a finalized QuantileStream.
|
||||
std::vector<float> GenerateQuantiles(const QuantileStream& stream,
|
||||
int num_quantiles) {
|
||||
// Do not de-dup boundaries. Exactly num_quantiles+1 boundary values
|
||||
// will be returned.
|
||||
std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles);
|
||||
CHECK_EQ(boundaries.size(), num_quantiles + 1);
|
||||
return boundaries;
|
||||
}
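// Note (illustrative, not in the original source): GenerateBoundaries above
// de-duplicates, so it may return fewer than the requested number of values
// (e.g. a constant feature collapses to a single boundary), whereas
// GenerateQuantiles always returns exactly num_quantiles + 1 values,
// duplicates included.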
|
||||
|
||||
// Copies quantiles to output list.
|
||||
void CopyBoundaries(OpKernelContext* const context,
|
||||
const std::vector<float>& boundaries, const int64 index,
|
||||
OpOutputList* output_list) {
|
||||
// Output to tensor.
|
||||
Tensor* output_t = nullptr;
|
||||
OP_REQUIRES_OK(
|
||||
context, output_list->allocate(
|
||||
index, {static_cast<int64>(boundaries.size())}, &output_t));
|
||||
auto* quantiles_flat = output_t->flat<float>().data();
|
||||
memcpy(quantiles_flat, boundaries.data(), sizeof(float) * boundaries.size());
|
||||
}
|
||||
|
||||
void CopySummaryToProto(const QuantileSummary& summary,
|
||||
::boosted_trees::QuantileSummaryState* summary_proto) {
|
||||
summary_proto->mutable_entries()->Reserve(summary.Size());
|
||||
for (const auto& entry : summary.GetEntryList()) {
|
||||
auto* new_entry = summary_proto->add_entries();
|
||||
new_entry->set_value(entry.value);
|
||||
new_entry->set_weight(entry.weight);
|
||||
new_entry->set_min_rank(entry.min_rank);
|
||||
new_entry->set_max_rank(entry.max_rank);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Accumulator for Quantile Summaries.
|
||||
REGISTER_RESOURCE_HANDLE_KERNEL(QuantileStreamResource);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("QuantileAccumulatorIsInitialized").Device(DEVICE_CPU),
|
||||
IsResourceInitialized<QuantileStreamResource>);
|
||||
|
||||
class CreateQuantileAccumulatorOp : public OpKernel {
|
||||
public:
|
||||
explicit CreateQuantileAccumulatorOp(OpKernelConstruction* const context)
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr(kEpsilonName, &epsilon_));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr(kNumQuantilesName, &num_quantiles_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr(kGenerateQuantiles, &generate_quantiles_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
    // Only create the resource if one does not already exist; report the
    // status for all other errors. If one already exists, the new resource
    // is unreffed.
|
||||
const Tensor* stamp_token_t;
|
||||
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
|
||||
    // An epsilon value of zero could cause performance issues and is
    // therefore disallowed.
|
||||
OP_REQUIRES(
|
||||
context, epsilon_ > 0,
|
||||
errors::InvalidArgument("An epsilon value of zero is not allowed."));
|
||||
auto result = new QuantileStreamResource(epsilon_, num_quantiles_,
|
||||
max_elements_, generate_quantiles_,
|
||||
stamp_token_t->scalar<int64>()());
|
||||
auto status = CreateResource(context, HandleFromInput(context, 0), result);
|
||||
if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
|
||||
OP_REQUIRES(context, false, status);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
float epsilon_;
|
||||
int32 num_quantiles_;
|
||||
// An upper bound on the number of entries that the summaries might have
|
||||
// for a feature.
|
||||
int64 max_elements_;
|
||||
bool generate_quantiles_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("CreateQuantileAccumulator").Device(DEVICE_CPU),
|
||||
CreateQuantileAccumulatorOp);
|
||||
|
||||
// Adds a summary to the quantile summary stream.
|
||||
class QuantileAccumulatorAddSummariesOp : public OpKernel {
|
||||
public:
|
||||
explicit QuantileAccumulatorAddSummariesOp(
|
||||
OpKernelConstruction* const context)
|
||||
: OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
OpInputList resource_handle_list;
|
||||
OP_REQUIRES_OK(context, context->input_list(kResourceHandlesName,
|
||||
&resource_handle_list));
|
||||
OpInputList summary_list;
|
||||
OP_REQUIRES_OK(context, context->input_list(kSummariesName, &summary_list));
|
||||
|
||||
const Tensor* stamp_token_t;
|
||||
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
|
||||
int64 stamp_token = stamp_token_t->scalar<int64>()();
|
||||
|
||||
thread::ThreadPool* const worker_threads =
|
||||
context->device()->tensorflow_cpu_worker_threads()->workers;
|
||||
boosted_trees::utils::ParallelFor(
|
||||
resource_handle_list.size(), worker_threads->NumThreads(),
|
||||
worker_threads,
|
||||
[&context, &resource_handle_list, &summary_list, stamp_token](
|
||||
int64 start, int64 end) {
|
||||
for (int resource_handle_idx = start; resource_handle_idx < end;
|
||||
++resource_handle_idx) {
|
||||
const ResourceHandle& handle =
|
||||
resource_handle_list[resource_handle_idx]
|
||||
.flat<ResourceHandle>()(0);
|
||||
core::RefCountPtr<QuantileStreamResource> streams_resource;
|
||||
// Create a reference to the underlying resource using the handle.
|
||||
OP_REQUIRES_OK(context,
|
||||
LookupResource(context, handle, &streams_resource));
|
||||
// Remove the reference at the end of this scope.
|
||||
mutex_lock l(*streams_resource->mutex());
|
||||
|
||||
// If the stamp is invalid we drop the update.
|
||||
if (!streams_resource->is_stamp_valid(stamp_token)) {
|
||||
VLOG(1)
|
||||
<< "Invalid stamp token in QuantileAccumulatorAddSummariesOp."
|
||||
<< " Passed stamp token: " << stamp_token << " "
|
||||
<< "Current token: " << streams_resource->stamp();
|
||||
return;
|
||||
}
|
||||
|
||||
protobuf::Arena arena;
|
||||
::boosted_trees::QuantileSummaryState* summary_proto =
|
||||
protobuf::Arena::CreateMessage<
|
||||
::boosted_trees::QuantileSummaryState>(&arena);
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
ParseProtoUnlimited(
|
||||
summary_proto,
|
||||
summary_list[resource_handle_idx].scalar<tstring>()()),
|
||||
errors::InvalidArgument("Unable to parse quantile summary."));
|
||||
std::vector<QuantileSummaryEntry> entries;
|
||||
entries.reserve(summary_proto->entries_size());
|
||||
for (const auto& entry : summary_proto->entries()) {
|
||||
entries.emplace_back(entry.value(), entry.weight(),
|
||||
entry.min_rank(), entry.max_rank());
|
||||
}
|
||||
|
||||
// Add the summary to the quantile stream.
|
||||
streams_resource->stream(stamp_token)->PushSummary(entries);
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("QuantileAccumulatorAddSummaries").Device(DEVICE_CPU),
|
||||
QuantileAccumulatorAddSummariesOp);
|
||||
|
||||
// Generates summaries for given set of float values, and the given config.
|
||||
class MakeQuantileSummariesOp : public OpKernel {
|
||||
public:
|
||||
explicit MakeQuantileSummariesOp(OpKernelConstruction* const context)
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context,
|
||||
ReadAndValidateAttributes(context, &num_dense_features_,
|
||||
&num_sparse_features_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr(kEpsilonName, &epsilon_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* const context) override {
|
||||
// Read dense float features list;
|
||||
OpInputList dense_float_features_list;
|
||||
OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
|
||||
context, &dense_float_features_list));
|
||||
|
||||
// Read sparse float features list;
|
||||
OpInputList sparse_float_feature_indices_list;
|
||||
OpInputList sparse_float_feature_values_list;
|
||||
OpInputList sparse_float_feature_shapes_list;
|
||||
OP_REQUIRES_OK(context, TensorUtils::ReadSparseFloatFeatures(
|
||||
context, &sparse_float_feature_indices_list,
|
||||
&sparse_float_feature_values_list,
|
||||
&sparse_float_feature_shapes_list));
|
||||
|
||||
// Parse example weights and get batch size.
|
||||
const Tensor* example_weights_t;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input(kExampleWeightsName, &example_weights_t));
|
||||
auto example_weights = example_weights_t->flat<float>();
|
||||
const int64 batch_size = example_weights.size();
|
||||
|
||||
OpOutputList sparse_summaries_output_list;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->output_list(kSparseSummariesName,
|
||||
&sparse_summaries_output_list));
|
||||
OpOutputList dense_summaries_output_list;
|
||||
OP_REQUIRES_OK(context, context->output_list(kDenseSummariesName,
|
||||
&dense_summaries_output_list));
|
||||
|
||||
auto do_quantile_summary_gen = [&](const int64 begin, const int64 end) {
|
||||
auto copy_over_summaries = [&](const QuantileStream& stream,
|
||||
const int64 index,
|
||||
OpOutputList* output_list) {
|
||||
protobuf::Arena arena;
|
||||
::boosted_trees::QuantileSummaryState* summary_proto =
|
||||
protobuf::Arena::CreateMessage<
|
||||
::boosted_trees::QuantileSummaryState>(&arena);
|
||||
const auto& summary = stream.GetFinalSummary();
|
||||
CopySummaryToProto(summary, summary_proto);
|
||||
// Output to tensor.
|
||||
Tensor* output_t = nullptr;
|
||||
OP_REQUIRES_OK(context, output_list->allocate(index, {}, &output_t));
|
||||
SerializeToTString(*summary_proto, &output_t->scalar<tstring>()());
|
||||
};
|
||||
|
||||
// These are blocks of ranges. We are iterating over both sparse and
|
||||
// dense features i.e. [0, sparse_features.size() + dense_features.size()]
|
||||
for (int64 i = begin; i < end; ++i) {
|
||||
if (i < num_dense_features_) {
|
||||
const int64 dense_index = i;
|
||||
const auto dense_values =
|
||||
dense_float_features_list[dense_index].flat<float>();
|
||||
QuantileStream stream(epsilon_, batch_size + 1);
|
||||
// Run quantile summary generation.
|
||||
for (int64 j = 0; j < batch_size; ++j) {
|
||||
stream.PushEntry(dense_values(j), example_weights(j));
|
||||
}
|
||||
stream.Finalize();
|
||||
// Copy summaries to output.
|
||||
copy_over_summaries(stream, dense_index,
|
||||
&dense_summaries_output_list);
|
||||
} else {
|
||||
const int64 sparse_index = i - num_dense_features_;
|
||||
const auto sparse_values =
|
||||
sparse_float_feature_values_list[sparse_index].flat<float>();
|
||||
const auto sparse_indices =
|
||||
sparse_float_feature_indices_list[sparse_index].matrix<int64>();
|
||||
const auto dense_shape =
|
||||
sparse_float_feature_shapes_list[sparse_index].flat<int64>();
|
||||
OP_REQUIRES(context, batch_size == dense_shape(0),
|
||||
errors::InvalidArgument(
|
||||
"Sparse column shape doesn't match the batch size."));
|
||||
QuantileStream stream(epsilon_, batch_size + 1);
|
||||
// Run quantile summary generation.
|
||||
const int64 num_sparse_rows =
|
||||
sparse_float_feature_indices_list[sparse_index].dim_size(0);
|
||||
for (int64 j = 0; j < num_sparse_rows; ++j) {
|
||||
const int64 example_id = sparse_indices(j, 0);
|
||||
stream.PushEntry(sparse_values(j), example_weights(example_id));
|
||||
}
|
||||
stream.Finalize();
|
||||
// Copy summaries to output.
|
||||
copy_over_summaries(stream, sparse_index,
|
||||
&sparse_summaries_output_list);
|
||||
}
|
||||
}
|
||||
};
|
||||
const int64 kCostPerUnit = 500 * batch_size;
|
||||
const int64 num_features = num_sparse_features_ + num_dense_features_;
|
||||
const DeviceBase::CpuWorkerThreads& worker_threads =
|
||||
*context->device()->tensorflow_cpu_worker_threads();
|
||||
Shard(worker_threads.num_threads, worker_threads.workers, num_features,
|
||||
kCostPerUnit, do_quantile_summary_gen);
|
||||
}
|
||||
|
||||
private:
|
||||
int num_dense_features_;
|
||||
int num_sparse_features_;
|
||||
float epsilon_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("MakeQuantileSummaries").Device(DEVICE_CPU),
|
||||
MakeQuantileSummariesOp);
|
||||
|
||||
// Serializes the state of streams.
|
||||
class QuantileAccumulatorSerializeOp : public OpKernel {
|
||||
public:
|
||||
explicit QuantileAccumulatorSerializeOp(OpKernelConstruction* const context)
|
||||
: OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
core::RefCountPtr<QuantileStreamResource> streams_resource;
|
||||
// Create a reference to the underlying resource using the handle.
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&streams_resource));
|
||||
// Remove the reference at the end of this scope.
|
||||
mutex_lock l(*streams_resource->mutex());
|
||||
|
||||
int64 stamp_token = streams_resource->stamp();
|
||||
Tensor* stream_state_t;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(kStreamStateName, TensorShape({}),
|
||||
&stream_state_t));
|
||||
bool are_buckets_ready = streams_resource->are_buckets_ready();
|
||||
|
||||
// We are iterating over both dense and sparse features. First we go
|
||||
// through the dense features and then the sparse features.
|
||||
const QuantileStream& stream = *streams_resource->stream(stamp_token);
|
||||
const std::vector<float>& boundaries =
|
||||
are_buckets_ready ? streams_resource->boundaries(stamp_token)
|
||||
: std::vector<float>();
|
||||
protobuf::Arena arena;
|
||||
::boosted_trees::QuantileStreamState* stream_proto =
|
||||
protobuf::Arena::CreateMessage<::boosted_trees::QuantileStreamState>(
|
||||
&arena);
|
||||
for (const auto& summary : stream.SerializeInternalSummaries()) {
|
||||
CopySummaryToProto(summary, stream_proto->add_summaries());
|
||||
}
|
||||
SerializeToTString(*stream_proto, &stream_state_t->scalar<tstring>()());
|
||||
Tensor* buckets_t = nullptr;
|
||||
OP_REQUIRES_OK(
|
||||
context,
|
||||
context->allocate_output(
|
||||
kBucketsName, {static_cast<int64>(boundaries.size())}, &buckets_t));
|
||||
auto* quantiles_flat = buckets_t->flat<float>().data();
|
||||
memcpy(quantiles_flat, boundaries.data(),
|
||||
sizeof(float) * boundaries.size());
|
||||
Tensor* stamp_token_t = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(kStampTokenName, TensorShape({}),
|
||||
&stamp_token_t));
|
||||
stamp_token_t->scalar<int64>()() = stamp_token;
|
||||
Tensor* are_buckets_ready_t = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(kAreBucketsReadyName, {},
|
||||
&are_buckets_ready_t));
|
||||
are_buckets_ready_t->scalar<bool>()() = are_buckets_ready;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("QuantileAccumulatorSerialize").Device(DEVICE_CPU),
|
||||
QuantileAccumulatorSerializeOp);
|
||||
|
||||
// Deserializes the state of streams.
|
||||
class QuantileAccumulatorDeserializeOp : public OpKernel {
|
||||
public:
|
||||
explicit QuantileAccumulatorDeserializeOp(OpKernelConstruction* const context)
|
||||
: OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
core::RefCountPtr<QuantileStreamResource> streams_resource;
|
||||
// Create a reference to the underlying resource using the handle.
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&streams_resource));
|
||||
// Remove the reference at the end of this scope.
|
||||
mutex_lock l(*streams_resource->mutex());
|
||||
|
||||
int64 old_stamp_token = streams_resource->stamp();
|
||||
|
||||
const Tensor* stream_state_t;
|
||||
OP_REQUIRES_OK(context, context->input(kStreamStateName, &stream_state_t));
|
||||
const Tensor* buckets_t;
|
||||
OP_REQUIRES_OK(context, context->input(kBucketsName, &buckets_t));
|
||||
|
||||
QuantileStream* stream = streams_resource->stream(old_stamp_token);
|
||||
::boosted_trees::QuantileStreamState state_proto;
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
ParseProtoUnlimited(&state_proto, stream_state_t->scalar<tstring>()()),
|
||||
errors::InvalidArgument("Unabnle to parse quantile stream state."));
|
||||
std::vector<QuantileSummary> summaries;
|
||||
summaries.reserve(state_proto.summaries_size());
|
||||
std::vector<QuantileSummaryEntry> entries;
|
||||
for (const auto& summary : state_proto.summaries()) {
|
||||
entries.clear();
|
||||
entries.reserve(summary.entries_size());
|
||||
for (const auto& entry : summary.entries()) {
|
||||
entries.emplace_back(entry.value(), entry.weight(), entry.min_rank(),
|
||||
entry.max_rank());
|
||||
}
|
||||
summaries.emplace_back();
|
||||
summaries[summaries.size() - 1].BuildFromSummaryEntries(entries);
|
||||
}
|
||||
stream->DeserializeInternalSummaries(summaries);
|
||||
|
||||
const auto& buckets = buckets_t->vec<float>();
|
||||
std::vector<float> result;
|
||||
result.reserve(buckets.size());
|
||||
|
||||
for (size_t i = 0; i < buckets.size(); ++i) {
|
||||
result.push_back(buckets(i));
|
||||
}
|
||||
streams_resource->set_boundaries(old_stamp_token, result);
|
||||
|
||||
// Reset the stamp token.
|
||||
const Tensor* stamp_token_t = nullptr;
|
||||
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
|
||||
int64 stamp_token = stamp_token_t->scalar<int64>()();
|
||||
streams_resource->set_stamp(stamp_token);
|
||||
|
||||
const Tensor* are_buckets_ready_t = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input(kAreBucketsReadyName, &are_buckets_ready_t));
|
||||
streams_resource->set_buckets_ready(are_buckets_ready_t->scalar<bool>()());
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("QuantileAccumulatorDeserialize").Device(DEVICE_CPU),
|
||||
QuantileAccumulatorDeserializeOp);
|
||||
|
||||
// Flushes the quantile summary stream resource.
|
||||
class QuantileAccumulatorFlushOp : public OpKernel {
|
||||
public:
|
||||
explicit QuantileAccumulatorFlushOp(OpKernelConstruction* const context)
|
||||
: OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
core::RefCountPtr<QuantileStreamResource> streams_resource;
|
||||
// Create a reference to the underlying resource using the handle.
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&streams_resource));
|
||||
// Remove the reference at the end of this scope.
|
||||
mutex_lock l(*streams_resource->mutex());
|
||||
|
||||
const Tensor* next_stamp_token_t;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input(kNextStampTokenName, &next_stamp_token_t));
|
||||
int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
|
||||
|
||||
const Tensor* stamp_token_t;
|
||||
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
|
||||
int64 stamp_token = stamp_token_t->scalar<int64>()();
|
||||
CHECK(streams_resource->is_stamp_valid(stamp_token))
|
||||
<< "Invalid stamp token in QuantileAccumulatorFlushOp. "
|
||||
<< "Passed stamp token: " << stamp_token << " "
|
||||
<< "Current token: " << streams_resource->stamp();
|
||||
QuantileStream* stream = streams_resource->stream(stamp_token);
|
||||
bool generate_quantiles = streams_resource->generate_quantiles();
|
||||
stream->Finalize();
|
||||
|
||||
streams_resource->set_boundaries(
|
||||
stamp_token,
|
||||
generate_quantiles
|
||||
? GenerateQuantiles(*stream, streams_resource->num_quantiles())
|
||||
: GenerateBoundaries(*stream, streams_resource->num_quantiles()));
|
||||
|
||||
streams_resource->Reset(next_stamp_token);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("QuantileAccumulatorFlush").Device(DEVICE_CPU),
|
||||
QuantileAccumulatorFlushOp);
|
||||
|
||||
// Flushes the quantile summary stream resource. This version computes the
|
||||
// summary.
|
||||
class QuantileAccumulatorFlushSummaryOp : public OpKernel {
|
||||
public:
|
||||
explicit QuantileAccumulatorFlushSummaryOp(
|
||||
OpKernelConstruction* const context)
|
||||
: OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
core::RefCountPtr<QuantileStreamResource> streams_resource;
|
||||
// Create a reference to the underlying resource using the handle.
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&streams_resource));
|
||||
// Remove the reference at the end of this scope.
|
||||
mutex_lock l(*streams_resource->mutex());
|
||||
|
||||
const Tensor* next_stamp_token_t;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input(kNextStampTokenName, &next_stamp_token_t));
|
||||
int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
|
||||
|
||||
const Tensor* stamp_token_t;
|
||||
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
|
||||
int64 stamp_token = stamp_token_t->scalar<int64>()();
|
||||
CHECK(streams_resource->is_stamp_valid(stamp_token))
|
||||
<< "Invalid stamp token in QuantileAccumulatorFlushSummaryOp. "
|
||||
<< "Passed stamp token: " << stamp_token << " "
|
||||
<< "Current token: " << streams_resource->stamp();
|
||||
QuantileStream* stream = streams_resource->stream(stamp_token);
|
||||
stream->Finalize();
|
||||
protobuf::Arena arena;
|
||||
::boosted_trees::QuantileSummaryState* summary_proto =
|
||||
protobuf::Arena::CreateMessage<::boosted_trees::QuantileSummaryState>(
|
||||
&arena);
|
||||
const auto& summary = stream->GetFinalSummary();
|
||||
CopySummaryToProto(summary, summary_proto);
|
||||
// Output to tensor.
|
||||
Tensor* output_t = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, TensorShape({}), &output_t));
|
||||
SerializeToTString(*summary_proto, &output_t->scalar<tstring>()());
|
||||
streams_resource->Reset(next_stamp_token);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("QuantileAccumulatorFlushSummary").Device(DEVICE_CPU),
|
||||
QuantileAccumulatorFlushSummaryOp);
|
||||
|
||||
// Get bucket boundaries from summaries.
|
||||
class QuantileAccumulatorGetBucketsOp : public OpKernel {
|
||||
public:
|
||||
explicit QuantileAccumulatorGetBucketsOp(OpKernelConstruction* const context)
|
||||
: OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* const context) override {
|
||||
OpInputList resource_handle_list;
|
||||
OP_REQUIRES_OK(context, context->input_list(kResourceHandlesName,
|
||||
&resource_handle_list));
|
||||
OpOutputList are_buckets_ready_list;
|
||||
OP_REQUIRES_OK(context, context->output_list(kAreBucketsReadyName,
|
||||
&are_buckets_ready_list));
|
||||
OpOutputList buckets_list;
|
||||
OP_REQUIRES_OK(context, context->output_list(kBucketsName, &buckets_list));
|
||||
const Tensor* stamp_token_t;
|
||||
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
|
||||
int64 stamp_token = stamp_token_t->scalar<int64>()();
|
||||
|
||||
thread::ThreadPool* const worker_threads =
|
||||
context->device()->tensorflow_cpu_worker_threads()->workers;
|
||||
boosted_trees::utils::ParallelFor(
|
||||
resource_handle_list.size(), worker_threads->NumThreads(),
|
||||
worker_threads,
|
||||
[&context, &resource_handle_list, &are_buckets_ready_list,
|
||||
&buckets_list, stamp_token](int64 start, int64 end) {
|
||||
for (int resource_handle_idx = start; resource_handle_idx < end;
|
||||
++resource_handle_idx) {
|
||||
const ResourceHandle& handle =
|
||||
resource_handle_list[resource_handle_idx]
|
||||
.flat<ResourceHandle>()(0);
|
||||
core::RefCountPtr<QuantileStreamResource> streams_resource;
|
||||
OP_REQUIRES_OK(context,
|
||||
LookupResource(context, handle, &streams_resource));
|
||||
// Remove the reference at the end of this scope.
|
||||
mutex_lock l(*streams_resource->mutex());
|
||||
|
||||
bool are_buckets_ready =
|
||||
streams_resource->is_stamp_valid(stamp_token) &&
|
||||
streams_resource->are_buckets_ready();
|
||||
|
||||
Tensor* are_buckets_ready_t = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
are_buckets_ready_list.allocate(
|
||||
resource_handle_idx, {}, &are_buckets_ready_t));
|
||||
are_buckets_ready_t->scalar<bool>()() = are_buckets_ready;
|
||||
|
||||
const std::vector<float>& boundaries =
|
||||
are_buckets_ready ? streams_resource->boundaries(stamp_token)
|
||||
: std::vector<float>();
|
||||
Tensor* output_t = nullptr;
|
||||
OP_REQUIRES_OK(context, buckets_list.allocate(
|
||||
resource_handle_idx,
|
||||
{static_cast<int64>(boundaries.size())},
|
||||
&output_t));
|
||||
auto* quantiles_flat = output_t->flat<float>().data();
|
||||
memcpy(quantiles_flat, boundaries.data(),
|
||||
sizeof(float) * boundaries.size());
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("QuantileAccumulatorGetBuckets").Device(DEVICE_CPU),
|
||||
QuantileAccumulatorGetBucketsOp);
|
||||
|
||||
// Generates buckets for given set of float values, and the given config.
|
||||
class QuantileBucketsOp : public OpKernel {
|
||||
public:
|
||||
explicit QuantileBucketsOp(OpKernelConstruction* const context)
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context,
|
||||
ReadAndValidateAttributes(context, &num_dense_features_,
|
||||
&num_sparse_features_));
|
||||
|
||||
ParseConfig(context, kDenseConfigName, &dense_configs_);
|
||||
OP_REQUIRES(context, dense_configs_.size() == num_dense_features_,
|
||||
errors::InvalidArgument(
|
||||
"Mismatch in number of dense quantile configs."));
|
||||
ParseConfig(context, kSparseConfigName, &sparse_configs_);
|
||||
OP_REQUIRES(context, sparse_configs_.size() == num_sparse_features_,
|
||||
errors::InvalidArgument(
|
||||
"Mismatch in number of sparse quantile configs."));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* const context) override {
|
||||
// Read dense float features list;
|
||||
OpInputList dense_float_features_list;
|
||||
OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
|
||||
context, &dense_float_features_list));
|
||||
|
||||
// Read sparse float features list;
|
||||
OpInputList sparse_float_feature_indices_list;
|
||||
OpInputList sparse_float_feature_values_list;
|
||||
OpInputList sparse_float_feature_shapes_list;
|
||||
OP_REQUIRES_OK(context, TensorUtils::ReadSparseFloatFeatures(
|
||||
context, &sparse_float_feature_indices_list,
|
||||
&sparse_float_feature_values_list,
|
||||
&sparse_float_feature_shapes_list));
|
||||
|
||||
// Parse example weights and get batch size.
|
||||
const Tensor* example_weights_t;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input(kExampleWeightsName, &example_weights_t));
|
||||
auto example_weights = example_weights_t->flat<float>();
|
||||
const int64 batch_size = example_weights.size();
|
||||
|
||||
OpOutputList sparse_buckets_output_list;
|
||||
OP_REQUIRES_OK(context, context->output_list(kSparseBucketsName,
|
||||
&sparse_buckets_output_list));
|
||||
OpOutputList dense_buckets_output_list;
|
||||
OP_REQUIRES_OK(context, context->output_list(kDenseBucketsName,
|
||||
&dense_buckets_output_list));
|
||||
|
||||
auto do_quantile_bucket_gen = [&](const int64 begin, const int64 end) {
|
||||
// These are blocks of ranges. We are iterating over both sparse and
|
||||
// dense features i.e. [0, sparse_features.size() + dense_features.size()]
|
||||
for (int64 i = begin; i < end; ++i) {
|
||||
if (i < sparse_configs_.size()) {
|
||||
const int64 sparse_index = i;
|
||||
const auto sparse_values =
|
||||
sparse_float_feature_values_list[sparse_index].flat<float>();
|
||||
const auto sparse_indices =
|
||||
sparse_float_feature_indices_list[sparse_index].matrix<int64>();
|
||||
QuantileStream stream(sparse_configs_[sparse_index].eps(),
|
||||
batch_size);
|
||||
// Run quantile summary generation.
|
||||
const int64 num_sparse_rows =
|
||||
sparse_float_feature_indices_list[sparse_index].dim_size(0);
|
||||
for (int64 j = 0; j < num_sparse_rows; ++j) {
|
||||
const int64 example_id = sparse_indices(j, 0);
|
||||
stream.PushEntry(sparse_values(j), example_weights(example_id));
|
||||
}
|
||||
stream.Finalize();
|
||||
// Create buckets.
|
||||
const auto boundaries = GenerateBoundaries(
|
||||
stream, sparse_configs_[sparse_index].num_quantiles());
|
||||
CopyBoundaries(context, boundaries, sparse_index,
|
||||
&sparse_buckets_output_list);
|
||||
|
||||
} else {
|
||||
const int64 dense_index = i - sparse_configs_.size();
|
||||
const auto dense_values =
|
||||
dense_float_features_list[dense_index].flat<float>();
|
||||
QuantileStream stream(dense_configs_[dense_index].eps(), batch_size);
|
||||
// Run quantile summary generation.
|
||||
for (int64 j = 0; j < batch_size; ++j) {
|
||||
stream.PushEntry(dense_values(j), example_weights(j));
|
||||
}
|
||||
stream.Finalize();
|
||||
// Create buckets.
|
||||
const auto boundaries = GenerateBoundaries(
|
||||
stream, dense_configs_[dense_index].num_quantiles());
|
||||
CopyBoundaries(context, boundaries, dense_index,
|
||||
&dense_buckets_output_list);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const int64 kCostPerUnit = 500 * batch_size;
|
||||
const int64 num_features = sparse_configs_.size() + dense_configs_.size();
|
||||
const DeviceBase::CpuWorkerThreads& worker_threads =
|
||||
*context->device()->tensorflow_cpu_worker_threads();
|
||||
Shard(worker_threads.num_threads, worker_threads.workers, num_features,
|
||||
kCostPerUnit, do_quantile_bucket_gen);
|
||||
}
|
||||
|
||||
private:
|
||||
int num_dense_features_;
|
||||
int num_sparse_features_;
|
||||
std::vector<QuantileConfig> dense_configs_;
|
||||
std::vector<QuantileConfig> sparse_configs_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("QuantileBuckets").Device(DEVICE_CPU),
|
||||
QuantileBucketsOp);

// Given the calculated quantiles thresholds and input data, this operation
// converts the input features into the buckets (categorical values), depending
// on which quantile they fall into.
class QuantilesOp : public OpKernel {
 public:
  explicit QuantilesOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    int num_dense_features;
    int num_sparse_features;
    OP_REQUIRES_OK(context,
                   ReadAndValidateAttributes(context, &num_dense_features,
                                             &num_sparse_features));
  }

  void Compute(OpKernelContext* const context) override {
    // Dense features inputs
    OpInputList dense_float_features_list;
    OP_REQUIRES_OK(context, context->input_list(kDenseValuesName,
                                                &dense_float_features_list));
    OpInputList dense_buckets_list;
    OP_REQUIRES_OK(context,
                   context->input_list(kDenseBucketsName, &dense_buckets_list));

    if (dense_buckets_list.size() > 0) {
      // Check the first tensor to make sure it is the right shape
      OP_REQUIRES(
          context,
          tensorflow::TensorShapeUtils::IsVector(dense_buckets_list[0].shape()),
          errors::InvalidArgument(
              strings::Printf("Dense buckets should be flat vectors")));
    }

    // Sparse features inputs
    OpInputList sparse_float_feature_values_list;
    OP_REQUIRES_OK(context,
                   context->input_list(kSparseValuesName,
                                       &sparse_float_feature_values_list));

    OpInputList sparse_float_indices_list;
    OP_REQUIRES_OK(context, context->input_list(kSparseIndicesName,
                                                &sparse_float_indices_list));

    OpInputList sparse_buckets_list;
    OP_REQUIRES_OK(
        context, context->input_list(kSparseBucketsName, &sparse_buckets_list));

    if (sparse_buckets_list.size() > 0) {
      OP_REQUIRES(
          context,
          tensorflow::TensorShapeUtils::IsVector(
              sparse_buckets_list[0].shape()),
          errors::InvalidArgument("Sparse buckets should be flat vectors"));
    }

    // Quantize the feature values
    QuantizeFeatures(kDenseOutputTensorName, dense_float_features_list,
                     dense_buckets_list, nullptr, context);

    QuantizeFeatures(kSparseOutputTensorName, sparse_float_feature_values_list,
                     sparse_buckets_list, &sparse_float_indices_list, context);
  }
};

REGISTER_KERNEL_BUILDER(Name("Quantiles").Device(DEVICE_CPU), QuantilesOp);

template <typename T>
class BucketizeWithInputBoundariesOp : public OpKernel {
 public:
  explicit BucketizeWithInputBoundariesOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& boundaries_tensor = context->input(1);
    VLOG(1) << "boundaries has shape: "
            << boundaries_tensor.shape().DebugString();
    auto boundaries = boundaries_tensor.flat<float>();
    std::vector<T> boundaries_vector;
    boundaries_vector.reserve(boundaries.size());
    for (size_t i = 0; i < boundaries.size(); i++) {
      boundaries_vector.push_back(boundaries(i));
      VLOG(1) << "boundaries(" << i << ") : " << boundaries(i);
    }
    OP_REQUIRES(
        context,
        std::is_sorted(boundaries_vector.begin(), boundaries_vector.end()),
        errors::InvalidArgument("Expected sorted boundaries"));

    const Tensor& input_tensor = context->input(0);
    VLOG(1) << "Inputs has shape: " << input_tensor.shape().DebugString()
            << " Dtype: " << tensorflow::DataTypeString(input_tensor.dtype());
    auto input = input_tensor.flat<T>();

    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output = output_tensor->template flat<int32>();

    for (size_t i = 0; i < input.size(); i++) {
      output(i) = CalculateBucketIndex(input(i), boundaries_vector);
    }
  }

 private:
  int32 CalculateBucketIndex(const T value, std::vector<T>& boundaries_vector) {
    auto first_bigger_it = std::upper_bound(boundaries_vector.begin(),
                                            boundaries_vector.end(), value);
    int32 index = first_bigger_it - boundaries_vector.begin();
    CHECK(index >= 0 && index <= boundaries_vector.size())
        << "Invalid bucket index: " << index
        << " boundaries_vector.size(): " << boundaries_vector.size();
    return index;
  }
};

#define REGISTER_KERNEL(T)                                     \
  REGISTER_KERNEL_BUILDER(Name("BucketizeWithInputBoundaries") \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<T>("T"),         \
                          BucketizeWithInputBoundariesOp<T>);

REGISTER_KERNEL(int32);
REGISTER_KERNEL(int64);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL
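
// CalculateBucketIndex above maps a value to the number of boundaries that
// are <= the value, so with N boundaries the produced bucket ids span [0, N].
// A minimal standalone sketch of the same rule (illustrative only, reusing
// the <algorithm>/<vector> facilities already used in this file):
inline int32 BucketIndexSketch(float value, const std::vector<float>& cuts) {
  // std::upper_bound returns the first boundary strictly greater than
  // `value`; its distance from the beginning is the bucket id.
  return static_cast<int32>(
      std::upper_bound(cuts.begin(), cuts.end(), value) - cuts.begin());
}
// For cuts = {0.5f, 1.5f, 2.5f}: 0.3f -> 0, 1.5f -> 2, 3.0f -> 3.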

} // namespace tensorflow
File diff suppressed because it is too large
@ -1,782 +0,0 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <algorithm>
#include <iterator>
#include <map>
#include <string>
#include <vector>

#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
#include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace boosted_trees {

namespace {
const char* const kStampTokenName = "stamp_token";
const char* const kNextStampTokenName = "next_stamp_token";

struct PartitionKey {
  PartitionKey() : partition_id(-1), feature_id(-1), dimension(-1) {}

  PartitionKey(int32 p, int64 f, int32 d)
      : partition_id(p), feature_id(f), dimension(d) {}

  bool operator==(const PartitionKey& other) const {
    return (partition_id == other.partition_id) &&
           (dimension == other.dimension) && (feature_id == other.feature_id);
  }

  // Compare for PartitionKey.
  struct Less {
    bool operator()(const PartitionKey& a, const PartitionKey& b) const {
      if (a.partition_id < b.partition_id) {
        return true;
      }
      if ((a.partition_id == b.partition_id) && (a.dimension < b.dimension)) {
        return true;
      }
      if ((a.partition_id == b.partition_id) && (a.dimension == b.dimension) &&
          (a.feature_id < b.feature_id)) {
        return true;
      }
      return false;
    }
  };

  // Tree partition defined by traversing the tree to the leaf.
  int32 partition_id;

  // Feature column id.
  int64 feature_id;

  // Dimension within feature column.
  int32 dimension;
};

template <typename GradientType, typename HessianType>
class StatsAccumulatorResource : public boosted_trees::StampedResource {
  using StatsByPartition =
      std::map<PartitionKey, std::pair<GradientType, HessianType>,
               PartitionKey::Less>;

 public:
  StatsAccumulatorResource(const TensorShape& gradient_shape,
                           const TensorShape& hessian_shape)
      : gradient_shape_(gradient_shape),
        hessian_shape_(hessian_shape),
        num_updates_(0) {
    // If GradientType/HessianType is scalar float then the shapes should be
    // scalar and vice versa.
    CHECK_EQ((std::is_same<GradientType, float>::value),
             TensorShapeUtils::IsScalar(gradient_shape));
    CHECK_EQ((std::is_same<HessianType, float>::value),
             TensorShapeUtils::IsScalar(hessian_shape));
  }

  string DebugString() const override {
    return strings::StrCat("StatsAccumulatorResource[size=", values_.size(),
                           "]");
  }

  void Clear() {
    values_.clear();
    num_updates_ = 0;
  }

  tensorflow::mutex* mutex() { return &mu_; }
  StatsByPartition* mutable_values() { return &values_; }
  const StatsByPartition& values() const { return values_; }
  const int64& num_updates() const { return num_updates_; }
  void set_num_updates(int64 val) { num_updates_ = val; }
  const TensorShape& gradient_shape() const { return gradient_shape_; }
  const TensorShape& hessian_shape() const { return hessian_shape_; }

 private:
  // Key into a specific partition to accumulate stats for the specified
  // feature id.
  StatsByPartition values_;
  const TensorShape gradient_shape_;
  const TensorShape hessian_shape_;
  int64 num_updates_;
  tensorflow::mutex mu_;
  TF_DISALLOW_COPY_AND_ASSIGN(StatsAccumulatorResource);
};

using StatsAccumulatorScalarResource = StatsAccumulatorResource<float, float>;
using StatsAccumulatorTensorResource =
    StatsAccumulatorResource<std::vector<float>, std::vector<float>>;
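
// The accumulators below all follow the same pattern: a std::map ordered by
// PartitionKey::Less collapses every (partition, feature column, dimension)
// slot into a single running (gradient, hessian) sum. A minimal sketch of
// that merge step on scalar stats (illustrative only; the real ops also
// handle stamp tokens, locking and tensor-valued stats):
inline void AccumulateScalarStatSketch(
    std::map<PartitionKey, std::pair<float, float>, PartitionKey::Less>* stats,
    const PartitionKey& key, float gradient, float hessian) {
  auto it = stats->find(key);
  if (it == stats->end()) {
    // First update for this slot: start a fresh (gradient, hessian) pair.
    (*stats)[key] = {gradient, hessian};
  } else {
    // Subsequent updates add element-wise into the existing pair.
    it->second.first += gradient;
    it->second.second += hessian;
  }
}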

void SerializeScalarAccumulatorToOutput(
    const StatsAccumulatorScalarResource& resource, OpKernelContext* context) {
  int64 num_slots = resource.values().size();
  Tensor* partition_ids_t = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
                                                   TensorShape({num_slots}),
                                                   &partition_ids_t));
  auto partition_ids = partition_ids_t->vec<int32>();

  // Feature ids tensor has ids of feature columns and their dimensions.
  Tensor* feature_ids_t = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids",
                                                   TensorShape({num_slots, 2}),
                                                   &feature_ids_t));
  auto feature_ids = feature_ids_t->matrix<int64>();

  Tensor* gradients_t = nullptr;
  OP_REQUIRES_OK(
      context, context->allocate_output(
                   "output_gradients", TensorShape({num_slots}), &gradients_t));
  auto gradients = gradients_t->vec<float>();

  Tensor* hessians_t = nullptr;
  OP_REQUIRES_OK(
      context, context->allocate_output("output_hessians",
                                        TensorShape({num_slots}), &hessians_t));
  auto hessians = hessians_t->vec<float>();

  int i = 0;
  for (const auto& iter : resource.values()) {
    partition_ids(i) = iter.first.partition_id;
    feature_ids(i, 0) = iter.first.feature_id;
    feature_ids(i, 1) = iter.first.dimension;

    gradients(i) = iter.second.first;
    hessians(i) = iter.second.second;
    ++i;
  }
}

void SerializeTensorAccumulatorToOutput(
    const StatsAccumulatorTensorResource& resource, OpKernelContext* context) {
  int64 num_slots = resource.values().size();
  Tensor* partition_ids_t = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
                                                   TensorShape({num_slots}),
                                                   &partition_ids_t));
  auto partition_ids = partition_ids_t->vec<int32>();

  Tensor* feature_ids_t = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids",
                                                   TensorShape({num_slots, 2}),
                                                   &feature_ids_t));
  auto feature_ids = feature_ids_t->matrix<int64>();

  TensorShape gradient_shape = resource.gradient_shape();
  int64 num_gradient_elements = gradient_shape.num_elements();
  gradient_shape.InsertDim(0, num_slots);
  Tensor* gradients_t = nullptr;
  OP_REQUIRES_OK(context,
                 context->allocate_output("output_gradients", gradient_shape,
                                          &gradients_t));
  auto gradients = gradients_t->flat_outer_dims<float>();

  TensorShape hessian_shape = resource.hessian_shape();
  int64 num_hessian_elements = hessian_shape.num_elements();
  hessian_shape.InsertDim(0, num_slots);
  Tensor* hessians_t = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output("output_hessians",
                                                   hessian_shape, &hessians_t));
  auto hessians = hessians_t->flat_outer_dims<float>();

  int i = 0;
  for (const auto& iter : resource.values()) {
    partition_ids(i) = iter.first.partition_id;
    feature_ids(i, 0) = iter.first.feature_id;
    feature_ids(i, 1) = iter.first.dimension;

    for (int j = 0; j < num_gradient_elements; ++j) {
      gradients(i, j) = iter.second.first[j];
    }
    for (int j = 0; j < num_hessian_elements; ++j) {
      hessians(i, j) = iter.second.second[j];
    }
    ++i;
  }
}

void AddToScalarAccumulator(
    const core::RefCountPtr<StatsAccumulatorScalarResource>& resource,
    const Tensor& partition_ids_t, const Tensor& feature_ids_t,
    const Tensor& gradients_t, const Tensor& hessians_t) {
  resource->set_num_updates(resource->num_updates() + 1);
  const TensorShape& partition_ids_shape = partition_ids_t.shape();
  const auto& partition_ids = partition_ids_t.vec<int32>();
  const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
  const auto& gradients = gradients_t.vec<float>();
  const auto& hessians = hessians_t.vec<float>();

  int64 num_updates = partition_ids_shape.dim_size(0);
  auto stats_map = resource->mutable_values();
  for (int64 i = 0; i < num_updates; ++i) {
    const auto key =
        PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
                     feature_ids_and_dimensions(i, 1));
    auto itr = stats_map->find(key);
    if (itr != stats_map->end()) {
      itr->second.first += gradients(i);
      itr->second.second += hessians(i);
    } else {
      (*stats_map)[key] = {gradients(i), hessians(i)};
    }
  }
}

void AddToScalarAccumulator(
    const core::RefCountPtr<StatsAccumulatorScalarResource>& resource,
    OpKernelContext* context) {
  const Tensor* partition_ids_t;
  OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
  const Tensor* feature_ids_t;
  OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
  const Tensor* gradients_t;
  OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
  const Tensor* hessians_t;
  OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
  AddToScalarAccumulator(resource, *partition_ids_t, *feature_ids_t,
                         *gradients_t, *hessians_t);
}

void AddToTensorAccumulator(
    const core::RefCountPtr<StatsAccumulatorTensorResource>& resource,
    const Tensor& partition_ids_t, const Tensor& feature_ids_t,
    const Tensor& gradients_t, const Tensor& hessians_t,
    OpKernelContext* context) {
  resource->set_num_updates(resource->num_updates() + 1);

  const TensorShape& partition_ids_shape = partition_ids_t.shape();
  const auto& partition_ids = partition_ids_t.vec<int32>();
  const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
  TensorShape gradients_shape = gradients_t.shape();
  const auto& gradients = gradients_t.flat_outer_dims<float>();
  TensorShape hessians_shape = hessians_t.shape();
  const auto& hessians = hessians_t.flat_outer_dims<float>();

  gradients_shape.RemoveDim(0);
  hessians_shape.RemoveDim(0);

  // TODO(soroush): Move gradient and hessian shape check to ShapeFn.
  OP_REQUIRES(
      context, gradients_shape == resource->gradient_shape(),
      errors::InvalidArgument(strings::StrCat(
          "Gradients dimensions must match: ", gradients_shape.DebugString(),
          ", ", resource->gradient_shape().DebugString())));

  OP_REQUIRES(
      context, hessians_shape == resource->hessian_shape(),
      errors::InvalidArgument(strings::StrCat(
          "Hessian dimensions must match: ", hessians_shape.DebugString(), ", ",
          resource->hessian_shape().DebugString())));

  int64 num_updates = partition_ids_shape.dim_size(0);
  auto stats_map = resource->mutable_values();
  for (int64 i = 0; i < num_updates; ++i) {
    const auto key =
        PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
                     feature_ids_and_dimensions(i, 1));
    auto itr = stats_map->find(key);
    if (itr == stats_map->end()) {
      std::vector<float> new_gradients(gradients_shape.num_elements());
      for (int j = 0; j < gradients_shape.num_elements(); ++j) {
        new_gradients[j] = gradients(i, j);
      }
      std::vector<float> new_hessians(hessians_shape.num_elements());
      for (int j = 0; j < hessians_shape.num_elements(); ++j) {
        new_hessians[j] = hessians(i, j);
      }
      (*stats_map)[key] = {new_gradients, new_hessians};
    } else {
      auto& stored_gradients = itr->second.first;
      for (int j = 0; j < gradients_shape.num_elements(); ++j) {
        stored_gradients[j] += gradients(i, j);
      }
      auto& stored_hessians = itr->second.second;
      for (int j = 0; j < hessians_shape.num_elements(); ++j) {
        stored_hessians[j] += hessians(i, j);
      }
    }
  }
}

void AddToTensorAccumulator(
    const core::RefCountPtr<StatsAccumulatorTensorResource>& resource,
    OpKernelContext* context) {
  const Tensor* partition_ids_t;
  OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
  const Tensor* feature_ids_t;
  OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
  const Tensor* gradients_t;
  OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
  const Tensor* hessians_t;
  OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
  AddToTensorAccumulator(resource, *partition_ids_t, *feature_ids_t,
                         *gradients_t, *hessians_t, context);
}

} // namespace

REGISTER_RESOURCE_HANDLE_KERNEL(StatsAccumulatorScalarResource);
REGISTER_RESOURCE_HANDLE_KERNEL(StatsAccumulatorTensorResource);

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorScalarIsInitialized").Device(DEVICE_CPU),
    IsResourceInitialized<StatsAccumulatorScalarResource>);

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorTensorIsInitialized").Device(DEVICE_CPU),
    IsResourceInitialized<StatsAccumulatorTensorResource>);

class CreateStatsAccumulatorScalarOp : public OpKernel {
 public:
  explicit CreateStatsAccumulatorScalarOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));

    TensorShape gradient_shape = TensorShape({});
    TensorShape hessian_shape = TensorShape({});

    auto* result =
        new StatsAccumulatorScalarResource(gradient_shape, hessian_shape);
    result->set_stamp(stamp_token_t->scalar<int64>()());
    // Only create one, if one does not exist already. Report status for all
    // other exceptions. If one already exists, it unrefs the new one.
    auto status = CreateResource(context, HandleFromInput(context, 0), result);
    if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
      OP_REQUIRES(context, false, status);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("CreateStatsAccumulatorScalar").Device(DEVICE_CPU),
                        CreateStatsAccumulatorScalarOp);

class CreateStatsAccumulatorTensorOp : public OpKernel {
 public:
  explicit CreateStatsAccumulatorTensorOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));

    const Tensor* gradient_shape_t;
    OP_REQUIRES_OK(
        context, context->input("per_slot_gradient_shape", &gradient_shape_t));

    const Tensor* hessian_shape_t;
    OP_REQUIRES_OK(context,
                   context->input("per_slot_hessian_shape", &hessian_shape_t));
    TensorShape gradient_shape = TensorShape(gradient_shape_t->vec<int64>());
    TensorShape hessian_shape = TensorShape(hessian_shape_t->vec<int64>());
    auto* result =
        new StatsAccumulatorTensorResource(gradient_shape, hessian_shape);
    result->set_stamp(stamp_token_t->scalar<int64>()());

    // Only create one, if one does not exist already. Report status for all
    // other exceptions. If one already exists, it unrefs the new one.
    auto status = CreateResource(context, HandleFromInput(context, 0), result);
    if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
      OP_REQUIRES(context, false, status);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("CreateStatsAccumulatorTensor").Device(DEVICE_CPU),
                        CreateStatsAccumulatorTensorOp);

class StatsAccumulatorScalarAddOp : public OpKernel {
 public:
  explicit StatsAccumulatorScalarAddOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    OpInputList resource_handle_list;
    OP_REQUIRES_OK(context, context->input_list("stats_accumulator_handles",
                                                &resource_handle_list));
    OpInputList partition_ids_list;
    OP_REQUIRES_OK(context,
                   context->input_list("partition_ids", &partition_ids_list));

    OpInputList feature_ids_list;
    OP_REQUIRES_OK(context,
                   context->input_list("feature_ids", &feature_ids_list));
    OpInputList gradients_list;
    OP_REQUIRES_OK(context, context->input_list("gradients", &gradients_list));
    OpInputList hessians_list;
    OP_REQUIRES_OK(context, context->input_list("hessians", &hessians_list));

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    thread::ThreadPool* const worker_threads =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    boosted_trees::utils::ParallelFor(
        resource_handle_list.size(), worker_threads->NumThreads(),
        worker_threads,
        [&context, &resource_handle_list, &partition_ids_list,
         &feature_ids_list, &gradients_list, &hessians_list,
         stamp_token](int64 start, int64 end) {
          for (int resource_handle_idx = start; resource_handle_idx < end;
               ++resource_handle_idx) {
            const ResourceHandle& handle =
                resource_handle_list[resource_handle_idx]
                    .flat<ResourceHandle>()(0);

            core::RefCountPtr<StatsAccumulatorScalarResource> resource;
            OP_REQUIRES_OK(context, LookupResource(context, handle, &resource));
            mutex_lock l(*resource->mutex());

            // If the stamp is invalid we drop the update.
            if (!resource->is_stamp_valid(stamp_token)) {
              VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
                      << "Passed stamp token: " << stamp_token << " "
                      << "Current token: " << resource->stamp();
              return;
            }
            AddToScalarAccumulator(resource,
                                   partition_ids_list[resource_handle_idx],
                                   feature_ids_list[resource_handle_idx],
                                   gradients_list[resource_handle_idx],
                                   hessians_list[resource_handle_idx]);
          }
        });
  }
};

REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorScalarAdd").Device(DEVICE_CPU),
                        StatsAccumulatorScalarAddOp);
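
// Every mutating op in this file follows the same stamp protocol: the caller
// passes the stamp token it believes the resource carries, and the op silently
// drops the update when the tokens disagree (for example after a flush has
// already rotated the stamp). A condensed sketch of that guard, mirroring the
// code above (illustrative only):
//
//   mutex_lock l(*resource->mutex());
//   if (!resource->is_stamp_valid(stamp_token)) {
//     VLOG(1) << "Stale stamp " << stamp_token << ", dropping update.";
//     return;  // The accumulator is left untouched.
//   }
//   AddToScalarAccumulator(resource, partition_ids, feature_ids,
//                          gradients, hessians);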

class StatsAccumulatorTensorAddOp : public OpKernel {
 public:
  explicit StatsAccumulatorTensorAddOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    OpInputList resource_handle_list;
    OP_REQUIRES_OK(context, context->input_list("stats_accumulator_handles",
                                                &resource_handle_list));
    OpInputList partition_ids_list;
    OP_REQUIRES_OK(context,
                   context->input_list("partition_ids", &partition_ids_list));

    OpInputList feature_ids_list;
    OP_REQUIRES_OK(context,
                   context->input_list("feature_ids", &feature_ids_list));
    OpInputList gradients_list;
    OP_REQUIRES_OK(context, context->input_list("gradients", &gradients_list));
    OpInputList hessians_list;
    OP_REQUIRES_OK(context, context->input_list("hessians", &hessians_list));

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    thread::ThreadPool* const worker_threads =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    boosted_trees::utils::ParallelFor(
        resource_handle_list.size(), worker_threads->NumThreads(),
        worker_threads,
        [&context, &resource_handle_list, &partition_ids_list,
         &feature_ids_list, &gradients_list, &hessians_list,
         stamp_token](int64 start, int64 end) {
          for (int resource_handle_idx = start; resource_handle_idx < end;
               ++resource_handle_idx) {
            const ResourceHandle& handle =
                resource_handle_list[resource_handle_idx]
                    .flat<ResourceHandle>()(0);

            core::RefCountPtr<StatsAccumulatorTensorResource> resource;
            OP_REQUIRES_OK(context, LookupResource(context, handle, &resource));
            mutex_lock l(*resource->mutex());

            // If the stamp is invalid we drop the update.
            if (!resource->is_stamp_valid(stamp_token)) {
              VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
                      << "Passed stamp token: " << stamp_token << " "
                      << "Current token: " << resource->stamp();
              return;
            }
            AddToTensorAccumulator(resource,
                                   partition_ids_list[resource_handle_idx],
                                   feature_ids_list[resource_handle_idx],
                                   gradients_list[resource_handle_idx],
                                   hessians_list[resource_handle_idx], context);
          }
        });
  }
};

REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorTensorAdd").Device(DEVICE_CPU),
                        StatsAccumulatorTensorAddOp);

class StatsAccumulatorScalarFlushOp : public OpKernel {
 public:
  explicit StatsAccumulatorScalarFlushOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    core::RefCountPtr<StatsAccumulatorScalarResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &resource));
    mutex_lock l(*resource->mutex());

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    // If the stamp is invalid we restart the PS. It shouldn't happen since
    // only Chief should call this function and chief is guaranteed to be in
    // a consistent state.
    CHECK(resource->is_stamp_valid(stamp_token));

    const Tensor* next_stamp_token_t;
    OP_REQUIRES_OK(context,
                   context->input(kNextStampTokenName, &next_stamp_token_t));
    int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
    CHECK(stamp_token != next_stamp_token);

    SerializeScalarAccumulatorToOutput(*resource, context);
    Tensor* num_updates_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("num_updates", TensorShape({}),
                                            &num_updates_t));
    num_updates_t->scalar<int64>()() = resource->num_updates();

    resource->Clear();
    resource->set_stamp(next_stamp_token);
  }
};

REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorScalarFlush").Device(DEVICE_CPU),
                        StatsAccumulatorScalarFlushOp);

class StatsAccumulatorTensorFlushOp : public OpKernel {
 public:
  explicit StatsAccumulatorTensorFlushOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    core::RefCountPtr<StatsAccumulatorTensorResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &resource));
    mutex_lock l(*resource->mutex());

    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    const Tensor* next_stamp_token_t;
    OP_REQUIRES_OK(context,
                   context->input(kNextStampTokenName, &next_stamp_token_t));
    int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();

    // If the stamp is invalid we restart the PS. It shouldn't happen since
    // only Chief should call this function and chief is guaranteed to be in
    // a consistent state.
    CHECK(resource->is_stamp_valid(stamp_token));
    CHECK(stamp_token != next_stamp_token);
    SerializeTensorAccumulatorToOutput(*resource, context);
    Tensor* num_updates_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("num_updates", TensorShape({}),
                                            &num_updates_t));
    num_updates_t->scalar<int64>()() = resource->num_updates();
    resource->Clear();
    resource->set_stamp(next_stamp_token);
  }
};

REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorTensorFlush").Device(DEVICE_CPU),
                        StatsAccumulatorTensorFlushOp);

class StatsAccumulatorScalarDeserializeOp : public OpKernel {
 public:
  explicit StatsAccumulatorScalarDeserializeOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    core::RefCountPtr<StatsAccumulatorScalarResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &resource));
    mutex_lock l(*resource->mutex());

    // Check the stamp token.
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();
    resource->Clear();
    resource->set_stamp(stamp_token);
    AddToScalarAccumulator(resource, context);
    const Tensor* num_updates_t;
    OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
    resource->set_num_updates(num_updates_t->scalar<int64>()());
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorScalarDeserialize").Device(DEVICE_CPU),
    StatsAccumulatorScalarDeserializeOp);

class StatsAccumulatorTensorDeserializeOp : public OpKernel {
 public:
  explicit StatsAccumulatorTensorDeserializeOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    core::RefCountPtr<StatsAccumulatorTensorResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &resource));
    mutex_lock l(*resource->mutex());

    // Check the stamp token.
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();
    resource->Clear();
    resource->set_stamp(stamp_token);
    AddToTensorAccumulator(resource, context);
    const Tensor* num_updates_t;
    OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
    resource->set_num_updates(num_updates_t->scalar<int64>()());
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorTensorDeserialize").Device(DEVICE_CPU),
    StatsAccumulatorTensorDeserializeOp);

class StatsAccumulatorScalarSerializeOp : public OpKernel {
 public:
  explicit StatsAccumulatorScalarSerializeOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    core::RefCountPtr<StatsAccumulatorScalarResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &resource));
    mutex_lock l(*resource->mutex());
    SerializeScalarAccumulatorToOutput(*resource, context);
    Tensor* stamp_token_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("stamp_token", TensorShape({}),
                                            &stamp_token_t));
    stamp_token_t->scalar<int64>()() = resource->stamp();

    Tensor* num_updates_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("num_updates", TensorShape({}),
                                            &num_updates_t));
    num_updates_t->scalar<int64>()() = resource->num_updates();
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorScalarSerialize").Device(DEVICE_CPU),
    StatsAccumulatorScalarSerializeOp);

class StatsAccumulatorTensorSerializeOp : public OpKernel {
 public:
  explicit StatsAccumulatorTensorSerializeOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    core::RefCountPtr<StatsAccumulatorTensorResource> resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &resource));
    mutex_lock l(*resource->mutex());
    SerializeTensorAccumulatorToOutput(*resource, context);
    Tensor* stamp_token_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("stamp_token", TensorShape({}),
                                            &stamp_token_t));
    stamp_token_t->scalar<int64>()() = resource->stamp();

    Tensor* num_updates_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("num_updates", TensorShape({}),
                                            &num_updates_t));
    num_updates_t->scalar<int64>()() = resource->num_updates();
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorTensorSerialize").Device(DEVICE_CPU),
    StatsAccumulatorTensorSerializeOp);

class StatsAccumulatorScalarMakeSummaryOp : public OpKernel {
 public:
  explicit StatsAccumulatorScalarMakeSummaryOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    TensorShape gradient_shape = TensorShape({});
    TensorShape hessian_shape = TensorShape({});
    core::RefCountPtr<StatsAccumulatorScalarResource> resource(
        new StatsAccumulatorScalarResource(gradient_shape, hessian_shape));
    // Check the stamp token.
    AddToScalarAccumulator(resource, context);
    SerializeScalarAccumulatorToOutput(*resource, context);
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorScalarMakeSummary").Device(DEVICE_CPU),
    StatsAccumulatorScalarMakeSummaryOp);

class StatsAccumulatorTensorMakeSummaryOp : public OpKernel {
 public:
  explicit StatsAccumulatorTensorMakeSummaryOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* gradients_t;
    OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
    TensorShape gradients_shape = gradients_t->shape();
    gradients_shape.RemoveDim(0);

    const Tensor* hessians_t;
    OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
    TensorShape hessians_shape = hessians_t->shape();
    hessians_shape.RemoveDim(0);

    core::RefCountPtr<StatsAccumulatorTensorResource> resource(
        new StatsAccumulatorTensorResource(gradients_shape, hessians_shape));
    // Check the stamp token.
    AddToTensorAccumulator(resource, context);
    SerializeTensorAccumulatorToOutput(*resource, context);
  }
};

REGISTER_KERNEL_BUILDER(
    Name("StatsAccumulatorTensorMakeSummary").Device(DEVICE_CPU),
    StatsAccumulatorTensorMakeSummaryOp);

} // namespace boosted_trees
} // namespace tensorflow
@ -1,996 +0,0 @@

// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <vector>

#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"
#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/refcount.h"

namespace tensorflow {
using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig;

namespace boosted_trees {

namespace {

using boosted_trees::learner::LearnerConfig;
using boosted_trees::learner::LearningRateConfig;
using boosted_trees::models::DecisionTreeEnsembleResource;
using boosted_trees::trees::DecisionTreeConfig;
using boosted_trees::trees::Leaf;
using boosted_trees::trees::TreeNode;
using boosted_trees::trees::TreeNodeMetadata;
using boosted_trees::utils::DropoutUtils;

// SplitCandidate holds the split candidate node along with the stats.
struct SplitCandidate {
  // Id of handler that generated the split candidate.
  int64 handler_id;

  // Split gain.
  float gain;

  // Split info.
  learner::SplitInfo split_info;

  // Oblivious split info.
  learner::ObliviousSplitInfo oblivious_split_info;
};

// Checks that the leaf is not empty.
bool IsLeafWellFormed(const Leaf& leaf) {
  return leaf.has_sparse_vector() || leaf.has_vector();
}

// Helper method to update the best split per partition given
// a current candidate.
void UpdateBestSplit(
    const boosted_trees::learner::LearnerConfig& learner_config,
    int32 partition_id, SplitCandidate* split,
    std::map<int32, SplitCandidate>* best_splits) {
  // Don't consider nodeless splits.
  if (TF_PREDICT_FALSE(split->split_info.split_node().node_case() ==
                       TreeNode::NODE_NOT_SET)) {
    return;
  }

  // Don't consider negative splits if we're pre-pruning the tree.
  // Note that zero-gain splits are acceptable as they're mostly doing as well
  // as what bias centering in that partition would do.
  if (learner_config.pruning_mode() ==
          boosted_trees::learner::LearnerConfig::PRE_PRUNE &&
      split->gain < 0) {
    return;
  }

  // If the current node is pure, one of the leafs will be empty, so the split
  // is meaningless and we should not split.
  if (!(IsLeafWellFormed(split->split_info.right_child()) &&
        IsLeafWellFormed(split->split_info.left_child()))) {
    VLOG(1) << "Split does not actually split anything";
    return;
  }

  // Take the split if we don't have a candidate yet.
  auto best_split_it = best_splits->find(partition_id);
  if (best_split_it == best_splits->end()) {
    best_splits->insert(std::make_pair(partition_id, std::move(*split)));
    return;
  }

  // Determine if best split so far needs to be replaced.
  SplitCandidate& best_split = best_split_it->second;
  if (TF_PREDICT_FALSE(split->gain == best_split.gain)) {
    // Tie break on node case preferring simpler tree node types.
    VLOG(2) << "Attempting to tie break with smaller node case. "
            << "(current split: " << split->split_info.split_node().node_case()
            << ", best split: "
            << best_split.split_info.split_node().node_case() << ")";
    if (split->split_info.split_node().node_case() <
        best_split.split_info.split_node().node_case()) {
      best_split = std::move(*split);
    } else if (split->split_info.split_node().node_case() ==
               best_split.split_info.split_node().node_case()) {
      // Tie break on handler Id.
      VLOG(2) << "Tie breaking with higher handler Id. "
              << "(current split: " << split->handler_id
              << ", best split: " << best_split.handler_id << ")";
      if (split->handler_id > best_split.handler_id) {
        best_split = std::move(*split);
      }
    }
  } else if (split->gain > best_split.gain) {
    best_split = std::move(*split);
  }
}

// Helper method to check whether a node is a terminal node in that it
// only has leaf nodes as children.
bool IsTerminalSplitNode(const size_t node_id,
                         const std::vector<int32>& children,
                         const std::vector<TreeNode>& nodes) {
  for (const int32 child_id : children) {
    const auto& child_node = nodes[child_id];
    CHECK(child_node.node_case() != TreeNode::NODE_NOT_SET);
    if (child_node.node_case() != TreeNode::kLeaf) {
      return false;
    }
  }
  return true;
}

// Helper method to recursively prune the tree in a depth-first fashion.
void RecursivePruneTree(const size_t node_id, std::vector<TreeNode>* nodes) {
  // Base case when we reach a leaf.
  TreeNode& tree_node = (*nodes)[node_id];
  CHECK(tree_node.node_case() != TreeNode::NODE_NOT_SET);
  if (tree_node.node_case() == TreeNode::kLeaf) {
    return;
  }

  // Traverse node children first and recursively prune their sub-trees.
  const std::vector<int32> children =
      boosted_trees::trees::DecisionTree::GetChildren(tree_node);
  for (const int32 child_id : children) {
    RecursivePruneTree(child_id, nodes);
  }

  // Two conditions must be satisfied to prune the node:
  // 1- The split gain is negative.
  // 2- After depth-first pruning, the node only has leaf children.
  TreeNodeMetadata* node_metadata = tree_node.mutable_node_metadata();
  if (node_metadata->gain() < 0 &&
      IsTerminalSplitNode(node_id, children, (*nodes))) {
    // Clear node children.
    for (const int32 child_id : children) {
      auto& child_node = (*nodes)[child_id];
      child_node.Clear();
    }

    // Change node back into leaf.
    (*tree_node.mutable_leaf()) = *node_metadata->mutable_original_leaf();

    // Clear gain for leaf node.
    tree_node.clear_node_metadata();
  } else {
    // Clear original leaf as it's no longer needed for back-track pruning.
    node_metadata->clear_original_leaf();
  }
}

} // namespace
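
// UpdateBestSplit above keeps at most one candidate per partition and resolves
// ties deterministically: higher gain wins, then the smaller (simpler)
// split-node case, then the higher handler id. As a rough sketch only (the
// real code also filters malformed and pre-pruned splits), that ordering is
// equivalent to comparing the tuples below:
//
//   candidate A replaces candidate B  <=>
//     (A.gain, -A.node_case, A.handler_id) >
//     (B.gain, -B.node_case, B.handler_id)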

class CenterTreeEnsembleBiasOp : public OpKernel {
 public:
  explicit CenterTreeEnsembleBiasOp(OpKernelConstruction* const context)
      : OpKernel(context) {
    // Read learner config.
    string serialized_learner_config;
    OP_REQUIRES_OK(context, context->GetAttr("learner_config",
                                             &serialized_learner_config));
    OP_REQUIRES(context,
                learner_config_.ParseFromString(serialized_learner_config),
                errors::InvalidArgument("Unable to parse learner config."));

    // Read centering epsilon.
    OP_REQUIRES_OK(context,
                   context->GetAttr("centering_epsilon", &centering_epsilon_));
  }

  void Compute(OpKernelContext* const context) override {
    // Get decision tree ensemble.
    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
    mutex_lock l(*ensemble_resource->get_mutex());

    // Get the stamp token.
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    // Only the Chief should run this Op and it is guaranteed to be in
    // a consistent state so the stamps must always match.
    CHECK(ensemble_resource->is_stamp_valid(stamp_token));

    // Get the next stamp token.
    const Tensor* next_stamp_token_t;
    OP_REQUIRES_OK(context,
                   context->input("next_stamp_token", &next_stamp_token_t));
    int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
    CHECK(stamp_token != next_stamp_token);

    // Update the ensemble stamp.
    ensemble_resource->set_stamp(next_stamp_token);

    // Get the delta updates.
    const Tensor* delta_updates_t;
    OP_REQUIRES_OK(context, context->input("delta_updates", &delta_updates_t));
    auto delta_updates = delta_updates_t->vec<float>();
    const int64 logits_dimension = delta_updates_t->dim_size(0);

    // Get the bias.
    boosted_trees::trees::Leaf* const bias =
        RetrieveBias(ensemble_resource, logits_dimension);
    CHECK(bias->has_vector());

    // Update the bias.
    float total_delta = 0;
    auto* bias_vec = bias->mutable_vector();
    for (size_t idx = 0; idx < bias->vector().value_size(); ++idx) {
      float delta = delta_updates(idx);
      bias_vec->set_value(idx, bias_vec->value(idx) + delta);
      total_delta += std::abs(delta);
    }

    // Make a centering continuation decision based on current update.
    bool continue_centering = total_delta > centering_epsilon_;
    if (continue_centering) {
      VLOG(1) << "Continuing to center bias, delta=" << total_delta;
    } else {
      VLOG(1) << "Done centering bias, delta=" << total_delta;
      ensemble_resource->LastTreeMetadata()->set_is_finalized(true);
    }
    Tensor* continue_centering_t = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output("continue_centering", TensorShape({}),
                                          &continue_centering_t));
    continue_centering_t->scalar<bool>()() = continue_centering;
  }

 private:
  // Helper method to retrieve the bias from the tree ensemble.
  Leaf* RetrieveBias(
      const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
      int64 logits_dimension) {
    const int32 num_trees = ensemble_resource->num_trees();
    if (num_trees <= 0) {
      // Add a new bias leaf.
      ensemble_resource->IncrementAttempts();
      boosted_trees::trees::DecisionTreeConfig* const tree_config =
          ensemble_resource->AddNewTree(1.0);
      auto* const leaf = tree_config->add_nodes()->mutable_leaf();
      for (size_t idx = 0; idx < logits_dimension; ++idx) {
        leaf->mutable_vector()->add_value(0.0);
      }
      return leaf;
    } else if (num_trees == 1) {
      // Confirms that the only tree is a bias and returns its leaf.
      boosted_trees::trees::DecisionTreeConfig* const tree_config =
          ensemble_resource->LastTree();
      CHECK_EQ(tree_config->nodes_size(), 1);
      CHECK_EQ(tree_config->nodes(0).node_case(), TreeNode::kLeaf);
      return tree_config->mutable_nodes(0)->mutable_leaf();
    } else {
      LOG(FATAL) << "Unable to center bias on an already grown ensemble";
    }
  }

  boosted_trees::learner::LearnerConfig learner_config_;
  float centering_epsilon_;
};

REGISTER_KERNEL_BUILDER(Name("CenterTreeEnsembleBias").Device(DEVICE_CPU),
                        CenterTreeEnsembleBiasOp);
|
||||
|
||||
class GrowTreeEnsembleOp : public OpKernel {
|
||||
public:
|
||||
explicit GrowTreeEnsembleOp(OpKernelConstruction* const context)
|
||||
: OpKernel(context) {
|
||||
// Read number of handlers, note that this is the static number of
|
||||
// all handlers but any subset of these handlers may be active at a time.
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_handlers", &num_handlers_));
|
||||
|
||||
OP_REQUIRES_OK(context, context->GetAttr("center_bias", ¢er_bias_));
|
||||
|
||||
// Read learner config.
|
||||
string serialized_learner_config;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("learner_config",
|
||||
&serialized_learner_config));
|
||||
OP_REQUIRES(context,
|
||||
learner_config_.ParseFromString(serialized_learner_config),
|
||||
errors::InvalidArgument("Unable to parse learner config."));
|
||||
|
||||
// Determine whether dropout was used when building this tree.
|
||||
if (learner_config_.has_learning_rate_tuner() &&
|
||||
learner_config_.learning_rate_tuner().tuner_case() ==
|
||||
LearningRateConfig::kDropout) {
|
||||
dropout_config_ = learner_config_.learning_rate_tuner().dropout();
|
||||
dropout_was_applied_ = true;
|
||||
} else {
|
||||
dropout_was_applied_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* const context) override {
|
||||
// Get decision tree ensemble.
|
||||
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
|
||||
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
|
||||
&ensemble_resource));
|
||||
mutex_lock l(*ensemble_resource->get_mutex());
|
||||
|
||||
// Get the stamp token.
|
||||
const Tensor* stamp_token_t;
|
||||
OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
|
||||
int64 stamp_token = stamp_token_t->scalar<int64>()();
|
||||
|
||||
// Only the Chief should run this Op and it is guaranteed to be in
|
||||
// a consistent state so the stamps must always match.
|
||||
CHECK(ensemble_resource->is_stamp_valid(stamp_token));
|
||||
|
||||
// Get the next stamp token.
|
||||
const Tensor* next_stamp_token_t;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input("next_stamp_token", &next_stamp_token_t));
|
||||
int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
|
||||
CHECK(stamp_token != next_stamp_token);
|
||||
|
||||
// Update the ensemble stamp regardless of whether a layer
|
||||
// or tree is actually grown.
|
||||
ensemble_resource->set_stamp(next_stamp_token);
|
||||
|
||||
// Read the learning_rate.
|
||||
const Tensor* learning_rate_t;
|
||||
OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t));
|
||||
float learning_rate = learning_rate_t->scalar<float>()();
|
||||
|
||||
// Read the weak learner type to use.
|
||||
const Tensor* weak_learner_type_t;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input("weak_learner_type", &weak_learner_type_t));
|
||||
const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
|
||||
|
||||
const Tensor* seed_t;
|
||||
OP_REQUIRES_OK(context, context->input("dropout_seed", &seed_t));
|
||||
// Cast seed to uint64.
|
||||
const uint64 dropout_seed = seed_t->scalar<int64>()();
|
||||
|
||||
// Read partition Ids, gains and split candidates.
|
||||
OpInputList partition_ids_list;
|
||||
OpInputList gains_list;
|
||||
OpInputList splits_list;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input_list("partition_ids", &partition_ids_list));
|
||||
OP_REQUIRES_OK(context, context->input_list("gains", &gains_list));
|
||||
OP_REQUIRES_OK(context, context->input_list("splits", &splits_list));
|
||||
|
||||
// Increment attempt stats.
|
||||
ensemble_resource->IncrementAttempts();
|
||||
|
||||
// Find best splits for each active partition.
|
||||
std::map<int32, SplitCandidate> best_splits;
|
||||
switch (weak_learner_type) {
|
||||
case LearnerConfig::NORMAL_DECISION_TREE: {
|
||||
FindBestSplitsPerPartitionNormal(context, partition_ids_list,
|
||||
gains_list, splits_list, &best_splits);
|
||||
break;
|
||||
}
|
||||
case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
|
||||
FindBestSplitOblivious(context, gains_list, splits_list, &best_splits);
|
||||
break;
|
||||
}
|
||||
}
|
||||
// No-op if no new splits can be considered.
|
||||
if (best_splits.empty()) {
|
||||
LOG(WARNING) << "Not growing tree ensemble as no good splits were found.";
|
||||
return;
|
||||
}
|
||||
|
||||
// Get the max tree depth.
|
||||
const Tensor* max_tree_depth_t;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input("max_tree_depth", &max_tree_depth_t));
|
||||
const int32 max_tree_depth = max_tree_depth_t->scalar<int32>()();
|
||||
// Update and retrieve the growable tree.
|
||||
// If the tree is fully built and dropout was applied, it also adjusts the
|
||||
// weights of dropped and the last tree.
|
||||
DecisionTreeConfig* const tree_config = UpdateAndRetrieveGrowableTree(
|
||||
ensemble_resource, learning_rate, dropout_seed, max_tree_depth,
|
||||
weak_learner_type);
|
||||
// Split tree nodes.
|
||||
switch (weak_learner_type) {
|
||||
case LearnerConfig::NORMAL_DECISION_TREE: {
|
||||
for (auto& split_entry : best_splits) {
|
||||
SplitTreeNode(split_entry.first, &split_entry.second, tree_config,
|
||||
ensemble_resource);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
|
||||
SplitTreeLayer(&best_splits[0], tree_config, ensemble_resource);
|
||||
}
|
||||
}
|
||||
// Post-prune finalized tree if needed.
|
||||
if (learner_config_.pruning_mode() ==
|
||||
boosted_trees::learner::LearnerConfig::POST_PRUNE &&
|
||||
ensemble_resource->LastTreeMetadata()->is_finalized()) {
|
||||
VLOG(2) << "Post-pruning finalized tree.";
|
||||
if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE) {
|
||||
LOG(FATAL) << "Post-prunning is not implemented for Oblivious trees.";
|
||||
}
|
||||
PruneTree(tree_config);
|
||||
|
||||
// If after post-pruning the whole tree has no gain, remove the tree
|
||||
// altogether from the ensemble.
|
||||
if (tree_config->nodes_size() <= 0) {
|
||||
ensemble_resource->RemoveLastTree();
|
||||
}
|
||||
|
||||
if ((ensemble_resource->num_trees() == 0 ||
|
||||
ensemble_resource->LastTreeMetadata()->is_finalized()) &&
|
||||
learner_config_.has_each_tree_start() &&
|
||||
learner_config_.each_tree_start().nodes_size() > 0) {
|
||||
DCHECK_GT(learner_config_.each_tree_start_num_layers(), 0);
|
||||
// Add new dummy tree
|
||||
boosted_trees::trees::DecisionTreeConfig* const tree_config =
|
||||
ensemble_resource->AddNewTree(learning_rate);
|
||||
VLOG(1) << "Adding a new forced tree";
|
||||
|
||||
*tree_config = learner_config_.each_tree_start();
|
||||
|
||||
boosted_trees::trees::DecisionTreeMetadata* const tree_metadata =
|
||||
ensemble_resource->LastTreeMetadata();
|
||||
|
||||
tree_metadata->set_is_finalized(max_tree_depth <= 1);
|
||||
tree_metadata->set_num_tree_weight_updates(1);
|
||||
tree_metadata->set_num_layers_grown(
|
||||
learner_config_.each_tree_start_num_layers());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
 private:
  // Helper method which effectively does a reduce over all split candidates
  // and finds the best split for each partition.
  void FindBestSplitsPerPartitionNormal(
      OpKernelContext* const context, const OpInputList& partition_ids_list,
      const OpInputList& gains_list, const OpInputList& splits_list,
      std::map<int32, SplitCandidate>* best_splits) {
    // Find best split per partition going through every feature candidate.
    // TODO(salehay): Is this worth parallelizing?
    for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) {
      const auto& partition_ids = partition_ids_list[handler_id].vec<int32>();
      const auto& gains = gains_list[handler_id].vec<float>();
      const auto& splits = splits_list[handler_id].vec<tstring>();
      OP_REQUIRES(context, partition_ids.size() == gains.size(),
                  errors::InvalidArgument(
                      "Inconsistent partition Ids and gains tensors: ",
                      partition_ids.size(), " != ", gains.size()));
      OP_REQUIRES(context, partition_ids.size() == splits.size(),
                  errors::InvalidArgument(
                      "Inconsistent partition Ids and splits tensors: ",
                      partition_ids.size(), " != ", splits.size()));
      for (size_t candidate_idx = 0; candidate_idx < splits.size();
           ++candidate_idx) {
        // Get current split candidate.
        const auto& partition_id = partition_ids(candidate_idx);
        const auto& gain = gains(candidate_idx);
        const auto& serialized_split = splits(candidate_idx);
        SplitCandidate split;
        split.handler_id = handler_id;
        split.gain = gain;
        OP_REQUIRES(context, split.split_info.ParseFromString(serialized_split),
                    errors::InvalidArgument("Unable to parse split info."));

        // Update best split for partition based on the current candidate.
        UpdateBestSplit(learner_config_, partition_id, &split, best_splits);
      }
    }
  }

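  // Illustrative sketch (not part of the original op): the loop above is a
  // reduce that keeps, per partition id, the candidate with the highest gain
  // via UpdateBestSplit (ignoring any tie-breaking it may also apply). The
  // Candidate type and function below are hypothetical, for intuition only.
  /*
    #include <cstdint>
    #include <map>

    struct Candidate {
      int64_t handler_id;
      float gain;
    };

    void KeepBestPerPartition(int32_t partition_id, const Candidate& candidate,
                              std::map<int32_t, Candidate>* best_by_partition) {
      auto it = best_by_partition->find(partition_id);
      if (it == best_by_partition->end() || candidate.gain > it->second.gain) {
        (*best_by_partition)[partition_id] = candidate;
      }
    }
  */
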
  void FindBestSplitOblivious(OpKernelContext* const context,
                              const OpInputList& gains_list,
                              const OpInputList& splits_list,
                              std::map<int32, SplitCandidate>* best_splits) {
    // Find best split per partition going through every feature candidate.
    for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) {
      const auto& gains = gains_list[handler_id].vec<float>();
      const auto& splits = splits_list[handler_id].vec<tstring>();
      OP_REQUIRES(context, gains.size() == 1,
                  errors::InvalidArgument(
                      "Gains size must be one for oblivious weak learner: ",
                      gains.size(), " != ", 1));
      OP_REQUIRES(context, splits.size() == 1,
                  errors::InvalidArgument(
                      "Splits size must be one for oblivious weak learner: ",
                      splits.size(), " != ", 1));
      // Get current split candidate.
      const auto& gain = gains(0);
      const auto& serialized_split = splits(0);
      SplitCandidate split;
      split.handler_id = handler_id;
      split.gain = gain;
      OP_REQUIRES(
          context, split.oblivious_split_info.ParseFromString(serialized_split),
          errors::InvalidArgument("Unable to parse oblivious split info."));

      auto split_info = split.oblivious_split_info;
      CHECK(split_info.children_size() % 2 == 0)
          << "The oblivious split should generate an even number of children: "
          << split_info.children_size();

      // If every node is pure, then we shouldn't split.
      bool only_pure_nodes = true;
      for (int idx = 0; idx < split_info.children_size(); idx += 2) {
        if (IsLeafWellFormed(*split_info.mutable_children(idx)) &&
            IsLeafWellFormed(*split_info.mutable_children(idx + 1))) {
          only_pure_nodes = false;
          break;
        }
      }
      if (only_pure_nodes) {
        VLOG(1) << "The oblivious split does not actually split anything.";
        continue;
      }

      // Don't consider negative splits if we're pre-pruning the tree.
      if (learner_config_.pruning_mode() ==
              learner::LearnerConfig::PRE_PRUNE &&
          gain < 0) {
        continue;
      }

      // Take the split if we don't have a candidate yet.
      auto best_split_it = best_splits->find(0);
      if (best_split_it == best_splits->end()) {
        best_splits->insert(std::make_pair(0, std::move(split)));
        continue;
      }

      // Determine if we should update best split.
      SplitCandidate& best_split = best_split_it->second;
      trees::TreeNode current_node = split_info.split_node();
      trees::TreeNode best_node = best_split.oblivious_split_info.split_node();
      if (TF_PREDICT_FALSE(gain == best_split.gain)) {
        // Tie break on node case preferring simpler tree node types.
        VLOG(2) << "Attempting to tie break with smaller node case. "
                << "(current split: " << current_node.node_case()
                << ", best split: " << best_node.node_case() << ")";
        if (current_node.node_case() < best_node.node_case()) {
          best_split = std::move(split);
        } else if (current_node.node_case() == best_node.node_case()) {
          // Tie break on handler Id.
          VLOG(2) << "Tie breaking with higher handler Id. "
                  << "(current split: " << handler_id
                  << ", best split: " << best_split.handler_id << ")";
          if (handler_id > best_split.handler_id) {
            best_split = std::move(split);
          }
        }
      } else if (gain > best_split.gain) {
        best_split = std::move(split);
      }
    }
  }

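  // Illustrative sketch (not part of the original op): the comparison above
  // orders candidates by higher gain, then (on exact gain ties) by the
  // smaller TreeNode::node_case(), then by the larger handler id. Written as
  // a standalone comparator over a hypothetical Candidate type:
  /*
    #include <cstdint>

    struct Candidate {
      int node_case;
      int64_t handler_id;
      float gain;
    };

    // Returns true if `a` should replace the current best candidate `b`.
    bool IsBetterCandidate(const Candidate& a, const Candidate& b) {
      if (a.gain != b.gain) return a.gain > b.gain;
      if (a.node_case != b.node_case) return a.node_case < b.node_case;
      return a.handler_id > b.handler_id;
    }
  */
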
  void UpdateTreeWeightsIfDropout(
      const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
      const uint64 dropout_seed) {
    // It is possible that the tree was built with dropout. If that is the
    // case, we need to adjust the tree weights, or bail out.
    if (!dropout_was_applied_ ||
        !ensemble_resource->LastTreeMetadata()->is_finalized()) {
      return;
    }
    const int32 num_trees = ensemble_resource->num_trees();

    // Based on the seed, figure out which trees were dropped before.
    std::unordered_set<int32> trees_not_to_drop;
    if (center_bias_) {
      trees_not_to_drop.insert(0);
    }
    // The last tree is the current tree that is being built.
    const int32 current_tree = num_trees - 1;
    trees_not_to_drop.insert(current_tree);

    // Since only the chief builds the trees, we are sure that the other tree
    // weights didn't change.
    std::vector<float> weights = ensemble_resource->GetTreeWeights();
    std::vector<int32> dropped_trees;
    std::vector<float> dropped_trees_weights;
    const auto dropout_status = DropoutUtils::DropOutTrees(
        dropout_seed, dropout_config_, trees_not_to_drop, weights,
        &dropped_trees, &dropped_trees_weights);
    CHECK(dropout_status.ok())
        << "Can't figure out what trees were dropped out before, error is "
        << dropout_status.error_message();

    // Now that we have the dropped trees, update their weights and the
    // current tree weight.
    if (!dropped_trees.empty()) {
      std::vector<int32> increment_num_updates(num_trees, 0);
      DropoutUtils::GetTreesWeightsForAddingTrees(
          dropped_trees, dropped_trees_weights, current_tree,
          1 /* only 1 tree was added */, &weights, &increment_num_updates);

      // Update the weights and number of updates for the trees.
      for (int i = 0; i < num_trees; ++i) {
        ensemble_resource->SetTreeWeight(i, weights[i],
                                         increment_num_updates[i]);
      }
    }
  }

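  // Illustrative sketch (not part of the original op): the exact bookkeeping
  // lives in DropoutUtils::GetTreesWeightsForAddingTrees and is not reproduced
  // here. For intuition only, the standard DART-style adjustment when k trees
  // were dropped and one new tree was added is shown below; whether it matches
  // the library's exact formula is an assumption, not a claim.
  /*
    #include <vector>

    void DartRescale(const std::vector<int>& dropped_trees, int new_tree,
                     std::vector<float>* weights) {
      const float k = static_cast<float>(dropped_trees.size());
      for (int idx : dropped_trees) {
        (*weights)[idx] *= k / (k + 1.0f);      // shrink previously dropped trees
      }
      (*weights)[new_tree] *= 1.0f / (k + 1.0f);  // scale down the new tree
    }
  */
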
  // Helper method to update the growable tree which is by definition the last
  // tree in the ensemble.
  DecisionTreeConfig* UpdateAndRetrieveGrowableTree(
      const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
      const float learning_rate, const uint64 dropout_seed,
      const int32 max_tree_depth, const int32 weak_learner_type) {
    const auto num_trees = ensemble_resource->num_trees();
    if (num_trees <= 0 ||
        ensemble_resource->LastTreeMetadata()->is_finalized()) {
      // Create a new tree with a no-op leaf.
      boosted_trees::trees::DecisionTreeConfig* const tree_config =
          ensemble_resource->AddNewTree(learning_rate);
      VLOG(1) << "Adding layer #0 to tree #" << num_trees << " of ensemble of "
              << num_trees + 1 << " trees.";
      tree_config->add_nodes()->mutable_leaf();
      boosted_trees::trees::DecisionTreeMetadata* const tree_metadata =
          ensemble_resource->LastTreeMetadata();
      tree_metadata->set_is_finalized(max_tree_depth <= 1);
      tree_metadata->set_num_tree_weight_updates(1);
    } else {
      // The growable tree is by definition the last tree in the ensemble.
      boosted_trees::trees::DecisionTreeMetadata* const tree_metadata =
          ensemble_resource->LastTreeMetadata();
      const auto new_num_layers = tree_metadata->num_layers_grown() + 1;
      VLOG(1) << "Adding layer #" << new_num_layers - 1 << " to tree #"
              << num_trees - 1 << " of ensemble of " << num_trees << " trees.";
      // Update growable tree metadata.
      tree_metadata->set_num_layers_grown(new_num_layers);
      tree_metadata->set_is_finalized(new_num_layers >= max_tree_depth);
    }
    UpdateTreeWeightsIfDropout(ensemble_resource, dropout_seed);
    return ensemble_resource->LastTree();
  }

  // Helper method to merge leaf weights as the tree is being grown.
  boosted_trees::trees::Leaf* MergeLeafWeights(
      const boosted_trees::trees::Leaf& source,
      boosted_trees::trees::Leaf* dest) {
    // Resolve the leaf merging method based on how the trees are being grown.
    if (learner_config_.growing_mode() ==
        boosted_trees::learner::LearnerConfig::WHOLE_TREE) {
      // No merging occurs when building a whole tree at a time.
      return dest;
    }

    if (dest->leaf_case() == boosted_trees::trees::Leaf::LEAF_NOT_SET) {
      // No merging is required. Just copy the source weights.
      *dest = source;
      return dest;
    }

    // Handle leaf merging based on type.
    switch (source.leaf_case()) {
      case boosted_trees::trees::Leaf::kVector: {
        // No-op if the source is empty.
        const auto& src_vec = source.vector();
        if (src_vec.value_size() == 0) {
          break;
        }
        CHECK(source.leaf_case() == dest->leaf_case());

        // Dense-add the leaf vectors.
        auto* dst_vec = dest->mutable_vector();
        CHECK(src_vec.value_size() == dst_vec->value_size());
        for (size_t idx = 0; idx < source.vector().value_size(); ++idx) {
          (*dst_vec->mutable_value()->Mutable(idx)) += src_vec.value(idx);
        }
        break;
      }
      case boosted_trees::trees::Leaf::kSparseVector: {
        // No-op if the source is empty.
        const auto& src_vec = source.sparse_vector();
        CHECK(src_vec.value_size() == src_vec.index_size());
        if (src_vec.value_size() == 0) {
          break;
        }
        CHECK(source.leaf_case() == dest->leaf_case());

        // Get the mapping of dimension to value for the destination.
        std::unordered_map<int32, float> dst_map;
        auto* dst_vec = dest->mutable_sparse_vector();
        CHECK(dst_vec->value_size() == dst_vec->index_size());
        dst_map.reserve(dst_vec->value_size());
        for (size_t idx = 0; idx < dst_vec->value_size(); ++idx) {
          dst_map[dst_vec->index(idx)] = dst_vec->value(idx);
        }
        // Sparse-add the source vector to the destination vector.
        for (size_t idx = 0; idx < src_vec.value_size(); ++idx) {
          dst_map[src_vec.index(idx)] += src_vec.value(idx);
        }
        // Rebuild the merged destination leaf.
        dst_vec->clear_index();
        dst_vec->clear_value();
        for (const auto& entry : dst_map) {
          dst_vec->add_index(entry.first);
          dst_vec->add_value(entry.second);
        }
        break;
      }
      case boosted_trees::trees::Leaf::LEAF_NOT_SET: {
        // No-op as there is nothing to merge.
        break;
      }
    }
    return dest;
  }

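  // Illustrative sketch (not part of the original file): the sparse branch
  // above is "accumulate into a hash map, then rebuild the parallel
  // index/value arrays". The same pattern on plain STL containers, with
  // hypothetical names:
  /*
    #include <unordered_map>
    #include <vector>

    void SparseAdd(const std::vector<int>& src_index,
                   const std::vector<float>& src_value,
                   std::vector<int>* dst_index, std::vector<float>* dst_value) {
      std::unordered_map<int, float> accum;
      for (size_t i = 0; i < dst_index->size(); ++i) {
        accum[(*dst_index)[i]] = (*dst_value)[i];
      }
      for (size_t i = 0; i < src_index.size(); ++i) {
        accum[src_index[i]] += src_value[i];
      }
      dst_index->clear();
      dst_value->clear();
      for (const auto& entry : accum) {
        dst_index->push_back(entry.first);
        dst_value->push_back(entry.second);
      }
    }
  */
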
  // Helper method to split a tree node and append its respective
  // leaf children given the split candidate.
  void SplitTreeNode(
      const int32 node_id, SplitCandidate* split,
      DecisionTreeConfig* tree_config,
      const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
    // Ensure that the node to split actually exists.
    CHECK(node_id < tree_config->nodes_size())
        << "Invalid node " << node_id << " to split.";
    // Ensure the new split node is valid.
    CHECK(split->split_info.split_node().node_case() != TreeNode::NODE_NOT_SET);
    CHECK(tree_config->nodes(node_id).node_case() == TreeNode::kLeaf)
        << "Unexpected node type to split "
        << tree_config->nodes(node_id).node_case() << " for node_id " << node_id
        << ". Tree config: " << tree_config->DebugString();

    // Add the left leaf.
    int32 left_id = tree_config->nodes_size();
    (*tree_config->add_nodes()->mutable_leaf()) =
        *MergeLeafWeights(tree_config->nodes(node_id).leaf(),
                          split->split_info.mutable_left_child());

    // Add the right leaf.
    int32 right_id = tree_config->nodes_size();
    (*tree_config->add_nodes()->mutable_leaf()) =
        *MergeLeafWeights(tree_config->nodes(node_id).leaf(),
                          split->split_info.mutable_right_child());

    // Link the newly appended children to the split node.
    boosted_trees::trees::DecisionTree::LinkChildren(
        {left_id, right_id}, split->split_info.mutable_split_node());

    // Add the split gain and, if needed, the original leaf to the node
    // metadata.
    TreeNodeMetadata* node_metadata =
        split->split_info.mutable_split_node()->mutable_node_metadata();
    node_metadata->set_gain(split->gain);
    if (learner_config_.pruning_mode() ==
        boosted_trees::learner::LearnerConfig::POST_PRUNE) {
      (*node_metadata->mutable_original_leaf()) =
          *tree_config->mutable_nodes(node_id)->mutable_leaf();
    }

    // Replace the node in the tree.
    (*tree_config->mutable_nodes(node_id)) =
        *split->split_info.mutable_split_node();
    if (learner_config_.constraints().max_number_of_unique_feature_columns()) {
      resource->MaybeAddUsedHandler(split->handler_id);
    }
  }

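  // Illustrative sketch (not part of the original op): structurally,
  // SplitTreeNode appends two child leaves to the node array and turns the
  // split leaf into an interior node that records their indices, which is
  // what DecisionTree::LinkChildren does above. The same idea on a simple
  // array-backed tree with hypothetical types:
  /*
    #include <vector>

    struct Node {
      bool is_leaf = true;
      int left = -1;
      int right = -1;
      float value = 0.0f;
    };

    void SplitLeaf(int node_id, float left_value, float right_value,
                   std::vector<Node>* nodes) {
      const int left_id = static_cast<int>(nodes->size());
      nodes->push_back(Node{true, -1, -1, left_value});
      const int right_id = static_cast<int>(nodes->size());
      nodes->push_back(Node{true, -1, -1, right_value});
      Node& parent = (*nodes)[node_id];
      parent.is_leaf = false;
      parent.left = left_id;
      parent.right = right_id;
    }
  */
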
  void SplitTreeLayer(
      SplitCandidate* split, DecisionTreeConfig* tree_config,
      const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
    int depth = 0;
    while (depth < tree_config->nodes_size() &&
           tree_config->nodes(depth).node_case() != TreeNode::kLeaf) {
      depth++;
    }
    CHECK(tree_config->nodes_size() > 0)
        << "A tree must have at least one dummy leaf.";
    // The number of new children.
    int num_children = 1 << (depth + 1);
    auto split_info = split->oblivious_split_info;
    CHECK(num_children >= split_info.children_size())
        << "Too many new children, expected <= " << num_children << " and got "
        << split_info.children_size();
    std::vector<trees::Leaf> new_leaves;
    new_leaves.reserve(num_children);
    int next_id = 0;
    for (int idx = 0; idx < num_children / 2; idx++) {
      trees::Leaf old_leaf =
          *tree_config->mutable_nodes(depth + idx)->mutable_leaf();
      // Check if a split was made for this leaf.
      if (next_id < split_info.children_parent_id_size() &&
          depth + idx == split_info.children_parent_id(next_id)) {
        // Add left leaf.
        new_leaves.push_back(*MergeLeafWeights(
            old_leaf, split_info.mutable_children(2 * next_id)));
        // Add right leaf.
        new_leaves.push_back(*MergeLeafWeights(
            old_leaf, split_info.mutable_children(2 * next_id + 1)));
        next_id++;
      } else {
        // If there is no split for this leaf, just duplicate it.
        new_leaves.push_back(old_leaf);
        new_leaves.push_back(old_leaf);
      }
    }
    CHECK(next_id == split_info.children_parent_id_size());
    TreeNodeMetadata* split_metadata =
        split_info.mutable_split_node()->mutable_node_metadata();
    split_metadata->set_gain(split->gain);

    TreeNode new_split = *split_info.mutable_split_node();
    // Move old children to metadata.
    for (int idx = depth; idx < tree_config->nodes_size(); idx++) {
      *new_split.mutable_node_metadata()->add_original_oblivious_leaves() =
          *tree_config->mutable_nodes(idx)->mutable_leaf();
    }
    // Add the new split to the tree_config in place before the children start.
    *tree_config->mutable_nodes(depth) = new_split;
    // Add the new children.
    int nodes_size = tree_config->nodes_size();
    for (int idx = 0; idx < num_children; idx++) {
      if (idx + depth + 1 < nodes_size) {
        // Update leaves that were already there.
        *tree_config->mutable_nodes(idx + depth + 1)->mutable_leaf() =
            new_leaves[idx];
      } else {
        // Add new leaves.
        *tree_config->add_nodes()->mutable_leaf() = new_leaves[idx];
      }
    }
  }
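
  // Illustrative sketch (not part of the original op): an oblivious tree
  // applies one shared test per level, so a tree of depth d needs only d
  // tests plus 2^d leaves; growing a layer doubles the leaf array (leaf i
  // becomes leaves 2*i and 2*i + 1), which is what the duplication/merge loop
  // above does on the proto representation. A conceptual version on a flat
  // leaf array, with hypothetical types:
  /*
    #include <utility>
    #include <vector>

    struct ObliviousTree {
      std::vector<int> level_feature;   // one split test per level
      std::vector<float> leaf_values;   // 2^depth leaf values
    };

    void GrowLayer(int feature, const std::vector<float>& left_delta,
                   const std::vector<float>& right_delta, ObliviousTree* tree) {
      std::vector<float> new_leaves;
      new_leaves.reserve(tree->leaf_values.size() * 2);
      for (size_t i = 0; i < tree->leaf_values.size(); ++i) {
        new_leaves.push_back(tree->leaf_values[i] + left_delta[i]);
        new_leaves.push_back(tree->leaf_values[i] + right_delta[i]);
      }
      tree->level_feature.push_back(feature);
      tree->leaf_values = std::move(new_leaves);
    }
  */
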
  void PruneTree(boosted_trees::trees::DecisionTreeConfig* tree_config) {
    // No-op if tree is empty.
    if (tree_config->nodes_size() <= 0) {
      return;
    }

    // Copy nodes to temp vector and clear original tree.
    std::vector<TreeNode> tree_nodes;
    tree_nodes.reserve(tree_config->nodes_size());
    for (auto& node : (*tree_config->mutable_nodes())) {
      tree_nodes.push_back(node);
      node.Clear();
    }
    tree_config->clear_nodes();

    // Prune the tree recursively starting from the root.
    RecursivePruneTree(0, &tree_nodes);

    // Rebuild compacted tree.
    (*tree_config->add_nodes()) = tree_nodes[0];
    std::unordered_map<size_t, size_t> nodes_map;
    nodes_map[0] = 0;
    for (size_t node_idx = 0; node_idx < tree_nodes.size(); ++node_idx) {
      // Skip pruned nodes.
      auto& original_node = tree_nodes[node_idx];
      if (original_node.node_case() == TreeNode::NODE_NOT_SET) {
        continue;
      }

      // Find node mapped in tree ensemble.
      auto mapped_node_it = nodes_map.find(node_idx);
      CHECK(mapped_node_it != nodes_map.end());
      auto& mapped_node = (*tree_config->mutable_nodes(mapped_node_it->second));

      // Get node children.
      auto children =
          boosted_trees::trees::DecisionTree::GetChildren(original_node);
      for (int32& child_idx : children) {
        auto new_idx = tree_config->nodes_size();
        (*tree_config->add_nodes()) = tree_nodes[child_idx];
        nodes_map[child_idx] = new_idx;
        child_idx = new_idx;
      }
      boosted_trees::trees::DecisionTree::LinkChildren(children, &mapped_node);
    }

    // Check if there are any nodes with gain left.
    if (tree_config->nodes_size() == 1 &&
        tree_config->nodes(0).node_metadata().gain() <= 0) {
      // The whole tree should be pruned.
      VLOG(2) << "No useful nodes left after post-pruning tree.";
      tree_config->clear_nodes();
    }
  }

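  // Illustrative sketch (not part of the original op): the rebuild above is
  // an index-compaction pass: surviving nodes are appended to a fresh array
  // in discovery order and child indices are remapped through an
  // old-index -> new-index table. A simplified version with a hypothetical
  // node type, assuming (as in the tree above) that a child always has a
  // larger index than its parent and that surviving nodes only reference
  // surviving children:
  /*
    #include <unordered_map>
    #include <vector>

    struct SimpleNode {
      bool pruned = false;
      std::vector<int> children;
    };

    std::vector<SimpleNode> CompactTree(const std::vector<SimpleNode>& old_nodes) {
      std::vector<SimpleNode> out;
      std::unordered_map<int, int> remap;
      out.push_back(old_nodes[0]);
      remap[0] = 0;
      for (int idx = 0; idx < static_cast<int>(old_nodes.size()); ++idx) {
        if (old_nodes[idx].pruned) continue;
        auto it = remap.find(idx);
        if (it == remap.end()) continue;  // unreachable after pruning
        std::vector<int> children = out[it->second].children;
        for (int& child : children) {
          remap[child] = static_cast<int>(out.size());
          out.push_back(old_nodes[child]);
          child = remap[child];
        }
        out[it->second].children = children;
      }
      return out;
    }
  */
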
 private:
  boosted_trees::learner::LearnerConfig learner_config_;
  int64 num_handlers_;
  LearningRateDropoutDrivenConfig dropout_config_;
  bool dropout_was_applied_;
  bool center_bias_;
};

REGISTER_KERNEL_BUILDER(Name("GrowTreeEnsemble").Device(DEVICE_CPU),
                        GrowTreeEnsembleOp);

class TreeEnsembleStatsOp : public OpKernel {
 public:
  explicit TreeEnsembleStatsOp(OpKernelConstruction* const context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* const context) override {
    // Get the decision tree ensemble.
    core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
    tf_shared_lock l(*ensemble_resource->get_mutex());

    // Get the stamp token.
    const Tensor* stamp_token_t;
    OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
    int64 stamp_token = stamp_token_t->scalar<int64>()();

    // Only the chief should run this op, and it is guaranteed to be in a
    // consistent state, so the stamps must always match.
    CHECK(ensemble_resource->is_stamp_valid(stamp_token));
    const boosted_trees::trees::DecisionTreeEnsembleConfig& ensemble_config =
        ensemble_resource->decision_tree_ensemble();

    // Set tree stats.
    Tensor* num_trees_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "num_trees", TensorShape({}), &num_trees_t));
    Tensor* active_tree_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("active_tree", TensorShape({}),
                                            &active_tree_t));
    Tensor* attempted_tree_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("attempted_trees", TensorShape({}),
                                            &attempted_tree_t));

    const int num_trees = ensemble_resource->num_trees();
    active_tree_t->scalar<int64>()() = num_trees;
    num_trees_t->scalar<int64>()() =
        (num_trees <= 0 ||
         ensemble_resource->LastTreeMetadata()->is_finalized())
            ? num_trees
            : num_trees - 1;
    attempted_tree_t->scalar<int64>()() =
        ensemble_config.growing_metadata().num_trees_attempted();

    // Set layer stats.
    Tensor* num_layers_t = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "num_layers", TensorShape({}), &num_layers_t));
    Tensor* active_layer_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("active_layer", TensorShape({}),
                                            &active_layer_t));
    Tensor* attempted_layers_t = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output("attempted_layers", TensorShape({}),
                                            &attempted_layers_t));

    int64 num_layers = 0;
    for (const auto& tree_metadata : ensemble_config.tree_metadata()) {
      num_layers += tree_metadata.num_layers_grown();
    }
    num_layers_t->scalar<int64>()() = num_layers;
    int tree_metadata_size = ensemble_config.tree_metadata_size();
    active_layer_t->scalar<int64>()() =
        tree_metadata_size > 0
            ? ensemble_config.tree_metadata(tree_metadata_size - 1)
                  .num_layers_grown()
            : 0;
    attempted_layers_t->scalar<int64>()() =
        ensemble_config.growing_metadata().num_layers_attempted();
  }
};

REGISTER_KERNEL_BUILDER(Name("TreeEnsembleStats").Device(DEVICE_CPU),
                        TreeEnsembleStatsOp);

}  // namespace boosted_trees
}  // namespace tensorflow
@ -1,447 +0,0 @@
# Description:
#   This directory contains common utilities used in boosted_trees.

load("//tensorflow:tensorflow.bzl", "py_test", "tf_cc_binary", "tf_cc_test")

package(
    default_visibility = [
        "//tensorflow/contrib/boosted_trees:__subpackages__",
        "//tensorflow/contrib/boosted_trees:friends",
    ],
    licenses = ["notice"],  # Apache 2.0
)

exports_files(["LICENSE"])

# Utils

cc_library(
    name = "utils",
    srcs = [
        "utils/batch_features.cc",
        "utils/dropout_utils.cc",
        "utils/examples_iterable.cc",
        "utils/parallel_for.cc",
        "utils/sparse_column_iterable.cc",
        "utils/tensor_utils.cc",
    ],
    hdrs = [
        "utils/batch_features.h",
        "utils/dropout_utils.h",
        "utils/example.h",
        "utils/examples_iterable.h",
        "utils/macros.h",
        "utils/optional_value.h",
        "utils/parallel_for.h",
        "utils/random.h",
        "utils/sparse_column_iterable.h",
        "utils/tensor_utils.h",
    ],
    deps = [
        "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
        "//tensorflow/core:framework_headers_lib",
        "//third_party/eigen3",
    ],
)

tf_cc_test(
    name = "sparse_column_iterable_test",
    size = "small",
    srcs = ["utils/sparse_column_iterable_test.cc"],
    deps = [
        ":utils",
        "//tensorflow/core:tensor_testutil",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

tf_cc_test(
    name = "examples_iterable_test",
    size = "small",
    srcs = ["utils/examples_iterable_test.cc"],
    deps = [
        ":utils",
        "//tensorflow/core:tensor_testutil",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "@com_google_absl//absl/algorithm:container",
    ],
)

tf_cc_test(
    name = "example_test",
    size = "small",
    srcs = ["utils/example_test.cc"],
    deps = [
        ":utils",
        "//tensorflow/core:tensor_testutil",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

tf_cc_test(
    name = "batch_features_test",
    size = "small",
    srcs = ["utils/batch_features_test.cc"],
    deps = [
        ":utils",
        "//tensorflow/core:lib",
        "//tensorflow/core:tensor_testutil",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

tf_cc_test(
    name = "dropout_utils_test",
    size = "small",
    srcs = ["utils/dropout_utils_test.cc"],
    deps = [
        ":utils",
        "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
        "//tensorflow/core:lib",
        "//tensorflow/core:tensor_testutil",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

# Models

cc_library(
    name = "models",
    srcs = ["models/multiple_additive_trees.cc"],
    hdrs = ["models/multiple_additive_trees.h"],
    deps = [
        ":trees",
        ":utils",
        "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
        "//tensorflow/core:framework_headers_lib",
    ],
)

tf_cc_test(
    name = "multiple_additive_trees_test",
    size = "small",
    srcs = ["models/multiple_additive_trees_test.cc"],
    deps = [
        ":batch_features_testutil",
        ":models",
        ":random_tree_gen",
        "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core:tensor_testutil",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

# Testutil

cc_library(
    name = "batch_features_testutil",
    testonly = 1,
    srcs = ["testutil/batch_features_testutil.cc"],
    hdrs = ["testutil/batch_features_testutil.h"],
    deps = [
        ":utils",
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core:test",
        "//tensorflow/core:testlib",
    ],
)

cc_library(
    name = "random_tree_gen",
    srcs = ["testutil/random_tree_gen.cc"],
    hdrs = ["testutil/random_tree_gen.h"],
    deps = [
        "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
        "//tensorflow/core:lib",
    ],
)

tf_cc_binary(
    name = "random_tree_gen_main",
    srcs = ["testutil/random_tree_gen_main.cc"],
    deps = [
        ":random_tree_gen",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
    ],
)

# Quantiles

cc_library(
    name = "weighted_quantiles",
    srcs = [],
    hdrs = [
        "quantiles/weighted_quantiles_buffer.h",
        "quantiles/weighted_quantiles_stream.h",
        "quantiles/weighted_quantiles_summary.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:framework_headers_lib",
    ],
)

tf_cc_test(
    name = "weighted_quantiles_buffer_test",
    size = "small",
    srcs = ["quantiles/weighted_quantiles_buffer_test.cc"],
    deps = [
        ":weighted_quantiles",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

tf_cc_test(
    name = "weighted_quantiles_summary_test",
    size = "small",
    srcs = ["quantiles/weighted_quantiles_summary_test.cc"],
    deps = [
        ":weighted_quantiles",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

tf_cc_test(
    name = "weighted_quantiles_stream_test",
    size = "small",
    srcs = ["quantiles/weighted_quantiles_stream_test.cc"],
    deps = [
        ":weighted_quantiles",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

# Trees

cc_library(
    name = "trees",
    srcs = ["trees/decision_tree.cc"],
    hdrs = ["trees/decision_tree.h"],
    deps = [
        "//tensorflow/contrib/boosted_trees/lib:utils",
        "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
        "//tensorflow/core:framework_headers_lib",
    ],
)

tf_cc_test(
    name = "trees_test",
    size = "small",
    srcs = ["trees/decision_tree_test.cc"],
    deps = [
        ":trees",
        "//tensorflow/contrib/boosted_trees/lib:utils",
        "//tensorflow/core:tensor_testutil",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

# Learner/batch

py_library(
    name = "base_split_handler",
    srcs = ["learner/batch/base_split_handler.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/contrib/boosted_trees:batch_ops_utils_py",
        "//tensorflow/python:control_flow_ops",
    ],
)

py_library(
    name = "categorical_split_handler",
    srcs = ["learner/batch/categorical_split_handler.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":base_split_handler",
        "//tensorflow/contrib/boosted_trees:split_handler_ops_py",
        "//tensorflow/contrib/boosted_trees:stats_accumulator_ops_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:math_ops",
    ],
)

py_test(
    name = "categorical_split_handler_test",
    srcs = ["learner/batch/categorical_split_handler_test.py"],
    python_version = "PY2",
    srcs_version = "PY2AND3",
    deps = [
        ":categorical_split_handler",
        "//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
        "//tensorflow/contrib/boosted_trees/proto:split_info_proto_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:resources",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:tensor_shape",
    ],
)

py_library(
    name = "ordinal_split_handler",
    srcs = ["learner/batch/ordinal_split_handler.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":base_split_handler",
        "//tensorflow/contrib/boosted_trees:quantile_ops_py",
        "//tensorflow/contrib/boosted_trees:split_handler_ops_py",
        "//tensorflow/contrib/boosted_trees:stats_accumulator_ops_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:function",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:sparse_tensor",
    ],
)

py_test(
    name = "ordinal_split_handler_test",
    srcs = ["learner/batch/ordinal_split_handler_test.py"],
    python_version = "PY2",
    srcs_version = "PY2AND3",
    deps = [
        ":ordinal_split_handler",
        "//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
        "//tensorflow/contrib/boosted_trees/proto:split_info_proto_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
        "//tensorflow/python:resources",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:tensor_shape",
    ],
)

# Learner/Common

cc_library(
    name = "class-partition-key",
    hdrs = ["learner/common/accumulators/class-partition-key.h"],
    deps = [
        "//tensorflow/core:framework_headers_lib",
    ],
)

cc_library(
    name = "feature-stats-accumulator",
    hdrs = ["learner/common/accumulators/feature-stats-accumulator.h"],
    deps = [
        ":class-partition-key",
    ],
)

tf_cc_test(
    name = "feature-stats-accumulator_test",
    size = "small",
    srcs = ["learner/common/accumulators/feature-stats-accumulator_test.cc"],
    deps = [
        ":feature-stats-accumulator",
        "//tensorflow/core:lib",
        "//tensorflow/core:tensor_testutil",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

cc_library(
    name = "example_partitioner",
    srcs = ["learner/common/partitioners/example_partitioner.cc"],
    hdrs = ["learner/common/partitioners/example_partitioner.h"],
    deps = [
        "//tensorflow/contrib/boosted_trees/lib:trees",
        "//tensorflow/contrib/boosted_trees/lib:utils",
        "//tensorflow/core:framework_headers_lib",
    ],
)

tf_cc_test(
    name = "example_partitioner_test",
    size = "small",
    srcs = ["learner/common/partitioners/example_partitioner_test.cc"],
    deps = [
        ":example_partitioner",
        "//tensorflow/core:tensor_testutil",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

# Learner/stochastic

cc_library(
    name = "gradient-stats",
    hdrs = ["learner/common/stats/gradient-stats.h"],
    deps = [
        "//tensorflow/core:framework_headers_lib",
        "//third_party/eigen3",
    ],
)

cc_library(
    name = "node-stats",
    hdrs = ["learner/common/stats/node-stats.h"],
    deps = [
        ":gradient-stats",
        "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
        "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
        "//tensorflow/core:framework_headers_lib",
        "//third_party/eigen3",
    ],
)

cc_library(
    name = "split-stats",
    hdrs = ["learner/common/stats/split-stats.h"],
    deps = [
        ":node-stats",
    ],
)

cc_library(
    name = "feature-split-candidate",
    hdrs = ["learner/common/stats/feature-split-candidate.h"],
    deps = [
        ":split-stats",
        "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
    ],
)

tf_cc_test(
    name = "node-stats_test",
    size = "small",
    srcs = ["learner/common/stats/node-stats_test.cc"],
    deps = [
        ":node-stats",
        "//tensorflow/core:tensor_testutil",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
Some files were not shown because too many files have changed in this diff.