[2.0] rm -rf tensorflow/contrib

PiperOrigin-RevId: 269816906
Gunhan Gulsoy 2019-09-18 08:59:39 -07:00 committed by TensorFlower Gardener
parent 776b99925c
commit ffc25308ce
1879 changed files with 22 additions and 440894 deletions


@@ -849,7 +849,7 @@ py_library(
visibility = ["//visibility:public"],
deps = select({
"api_version_2": [],
"//conditions:default": ["//tensorflow/contrib:contrib_py"],
"//conditions:default": [],
}) + [
":tensorflow_py_no_contrib",
"//tensorflow/python/estimator:estimator_py",


@@ -1,175 +0,0 @@
# Description:
# contains parts of TensorFlow that are experimental or unstable and which are not supported.
load("//tensorflow:tensorflow.bzl", "if_not_windows")
package(
default_visibility = ["//tensorflow:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
py_library(
name = "contrib_py",
srcs = glob(
["**/*.py"],
exclude = [
"**/*_test.py",
],
),
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/contrib/all_reduce",
"//tensorflow/contrib/batching:batch_py",
"//tensorflow/contrib/bayesflow:bayesflow_py",
"//tensorflow/contrib/boosted_trees:init_py",
"//tensorflow/contrib/checkpoint/python:checkpoint",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
"//tensorflow/contrib/compiler:compiler_py",
"//tensorflow/contrib/compiler:xla",
"//tensorflow/contrib/autograph",
"//tensorflow/contrib/constrained_optimization",
"//tensorflow/contrib/copy_graph:copy_graph_py",
"//tensorflow/contrib/crf:crf_py",
"//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py",
"//tensorflow/contrib/data",
"//tensorflow/contrib/deprecated:deprecated_py",
"//tensorflow/contrib/distribute:distribute",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/eager/python:tfe",
"//tensorflow/contrib/estimator:estimator_py",
"//tensorflow/contrib/factorization:factorization_py",
"//tensorflow/contrib/feature_column:feature_column_py",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
"//tensorflow/contrib/hadoop",
"//tensorflow/contrib/hooks",
"//tensorflow/contrib/image:distort_image_py",
"//tensorflow/contrib/image:image_py",
"//tensorflow/contrib/image:single_image_random_dot_stereograms_py",
"//tensorflow/contrib/input_pipeline:input_pipeline_py",
"//tensorflow/contrib/integrate:integrate_py",
"//tensorflow/contrib/keras",
"//tensorflow/contrib/kernel_methods",
"//tensorflow/contrib/labeled_tensor",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/contrib/learn:head_test_lib",
"//tensorflow/contrib/legacy_seq2seq:seq2seq_py",
"//tensorflow/contrib/libsvm",
"//tensorflow/contrib/linear_optimizer:sdca_estimator_py",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
"//tensorflow/contrib/lookup:lookup_py",
"//tensorflow/contrib/losses:losses_py",
"//tensorflow/contrib/losses:metric_learning_py",
"//tensorflow/contrib/memory_stats:memory_stats_py",
"//tensorflow/contrib/meta_graph_transform",
"//tensorflow/contrib/metrics:metrics_py",
"//tensorflow/contrib/mixed_precision:mixed_precision",
"//tensorflow/contrib/model_pruning",
"//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py",
"//tensorflow/contrib/nn:nn_py",
"//tensorflow/contrib/opt:opt_py",
"//tensorflow/contrib/optimizer_v2:optimizer_v2_py",
"//tensorflow/contrib/periodic_resample:init_py",
"//tensorflow/contrib/predictor",
"//tensorflow/contrib/proto",
"//tensorflow/contrib/quantization:quantization_py",
"//tensorflow/contrib/quantize:quantize_graph",
"//tensorflow/contrib/receptive_field:receptive_field_py",
"//tensorflow/contrib/recurrent:recurrent_py",
"//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py",
"//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py",
"//tensorflow/contrib/resampler:resampler_py",
"//tensorflow/contrib/rnn:rnn_py",
"//tensorflow/contrib/rpc",
"//tensorflow/contrib/saved_model:saved_model_py",
"//tensorflow/contrib/seq2seq:seq2seq_py",
"//tensorflow/contrib/signal:signal_py",
"//tensorflow/contrib/slim",
"//tensorflow/contrib/slim:nets",
"//tensorflow/contrib/solvers:solvers_py",
"//tensorflow/contrib/sparsemax:sparsemax_py",
"//tensorflow/contrib/specs",
"//tensorflow/contrib/staging",
"//tensorflow/contrib/stat_summarizer:stat_summarizer_py",
"//tensorflow/contrib/stateless",
"//tensorflow/contrib/summary:summary",
"//tensorflow/contrib/tensor_forest:init_py",
"//tensorflow/contrib/tensorboard",
"//tensorflow/contrib/testing:testing_py",
"//tensorflow/contrib/text:text_py",
"//tensorflow/contrib/tfprof",
"//tensorflow/contrib/timeseries",
"//tensorflow/contrib/tpu",
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:util",
"//tensorflow/python/estimator:estimator_py",
] + select({
"//tensorflow:android": [],
"//tensorflow:ios": [],
"//tensorflow:linux_s390x": [],
"//tensorflow:windows": [],
"//conditions:default": [
"//tensorflow/contrib/fused_conv:fused_conv_py",
"//tensorflow/contrib/tensorrt:init_py",
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
],
}) + select({
"//tensorflow:android": [],
"//tensorflow:ios": [],
"//tensorflow:linux_s390x": [],
"//tensorflow:windows": [],
"//tensorflow:no_gcp_support": [],
"//conditions:default": [
"//tensorflow/contrib/bigtable",
"//tensorflow/contrib/cloud:cloud_py",
],
}),
)
cc_library(
name = "contrib_kernels",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/contrib/boosted_trees:boosted_trees_kernels",
"//tensorflow/contrib/factorization/kernels:all_kernels",
"//tensorflow/contrib/hadoop:dataset_kernels",
"//tensorflow/contrib/image:image_ops_kernels",
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
"//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
"//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_kernels",
"//tensorflow/contrib/rnn:all_kernels",
"//tensorflow/contrib/seq2seq:beam_search_ops_kernels",
"//tensorflow/contrib/tensor_forest:model_ops_kernels",
"//tensorflow/contrib/tensor_forest:stats_ops_kernels",
"//tensorflow/contrib/tensor_forest:tensor_forest_kernels",
"//tensorflow/contrib/text:all_kernels",
] + if_not_windows([
"//tensorflow/contrib/tensorrt:trt_op_kernels",
]),
)
cc_library(
name = "contrib_ops_op_lib",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib",
"//tensorflow/contrib/factorization:all_ops",
"//tensorflow/contrib/framework:all_ops",
"//tensorflow/contrib/hadoop:dataset_ops_op_lib",
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib",
"//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
"//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_op_lib",
"//tensorflow/contrib/rnn:all_ops",
"//tensorflow/contrib/seq2seq:beam_search_ops_op_lib",
"//tensorflow/contrib/tensor_forest:model_ops_op_lib",
"//tensorflow/contrib/tensor_forest:stats_ops_op_lib",
"//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
"//tensorflow/contrib/text:all_ops",
] + if_not_windows([
"//tensorflow/compiler/tf2tensorrt:trt_op_libs",
]),
)


@@ -1,23 +0,0 @@
# TensorFlow contrib
Any code in this directory is not officially supported, and may change or be
removed at any time without notice.
The contrib directory contains project directories, each of which has designated
owners. It is meant to contain features and contributions that eventually should
get merged into core TensorFlow, but whose interfaces may still change, or which
require some testing to see whether they can find broader acceptance. We are
trying to keep duplication within contrib to a minimum, so you may be asked to
refactor code in contrib to use some feature inside core or in another project
in contrib rather than reimplementing the feature.
When adding a project, please stick to the following directory structure:
* Create a project directory in `contrib/`, and mirror the portions of the
TensorFlow tree that your project requires underneath `contrib/my_project/`.
For example, let's say you create foo ops in two files: `foo_ops.py` and
`foo_ops_test.py`. If you were to merge those files directly into TensorFlow,
they would live in `tensorflow/python/ops/foo_ops.py` and
`tensorflow/python/kernel_tests/foo_ops_test.py`. In `contrib/`, they are part
of project `foo`, and their full paths are `contrib/foo/python/ops/foo_ops.py`
and `contrib/foo/python/kernel_tests/foo_ops_test.py`.
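Drawn out, the `foo` example above corresponds to the following layout (only
the mirrored portions of the tree are shown):

```
contrib/foo/
  python/
    ops/foo_ops.py
    kernel_tests/foo_ops_test.py
```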


@@ -1,122 +0,0 @@
# pylint: disable=g-import-not-at-top
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contrib module containing volatile or experimental code.
Warning: The `tf.contrib` module will not be included in TensorFlow 2.0. Many
of its submodules have been integrated into TensorFlow core, or spun off into
other projects like [`tensorflow_io`](https://github.com/tensorflow/io) or
[`tensorflow_addons`](https://github.com/tensorflow/addons). For instructions
on how to upgrade, see the
[Migration guide](https://www.tensorflow.org/beta/guide/migration_guide).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import platform
# Add projects here, they will show up under tf.contrib.
from tensorflow.contrib import autograph
from tensorflow.contrib import batching
from tensorflow.contrib import bayesflow
from tensorflow.contrib import checkpoint
from tensorflow.contrib import cluster_resolver
from tensorflow.contrib import compiler
from tensorflow.contrib import constrained_optimization
from tensorflow.contrib import copy_graph
from tensorflow.contrib import crf
from tensorflow.contrib import cudnn_rnn
from tensorflow.contrib import data
from tensorflow.contrib import deprecated
from tensorflow.contrib import distribute
from tensorflow.contrib import distributions
from tensorflow.contrib import estimator
from tensorflow.contrib import factorization
from tensorflow.contrib import feature_column
from tensorflow.contrib import framework
from tensorflow.contrib import graph_editor
from tensorflow.contrib import grid_rnn
from tensorflow.contrib import image
from tensorflow.contrib import input_pipeline
from tensorflow.contrib import integrate
from tensorflow.contrib import keras
from tensorflow.contrib import kernel_methods
from tensorflow.contrib import labeled_tensor
from tensorflow.contrib import layers
from tensorflow.contrib import learn
from tensorflow.contrib import legacy_seq2seq
from tensorflow.contrib import linear_optimizer
from tensorflow.contrib import lookup
from tensorflow.contrib import losses
from tensorflow.contrib import memory_stats
from tensorflow.contrib import metrics
from tensorflow.contrib import mixed_precision
from tensorflow.contrib import model_pruning
from tensorflow.contrib import nn
from tensorflow.contrib import opt
from tensorflow.contrib import periodic_resample
from tensorflow.contrib import predictor
from tensorflow.contrib import proto
from tensorflow.contrib import quantization
from tensorflow.contrib import quantize
from tensorflow.contrib import reduce_slice_ops
from tensorflow.contrib import resampler
from tensorflow.contrib import rnn
from tensorflow.contrib import rpc
from tensorflow.contrib import saved_model
from tensorflow.contrib import seq2seq
from tensorflow.contrib import signal
from tensorflow.contrib import slim
from tensorflow.contrib import solvers
from tensorflow.contrib import sparsemax
from tensorflow.contrib import staging
from tensorflow.contrib import stat_summarizer
from tensorflow.contrib import stateless
from tensorflow.contrib import tensor_forest
from tensorflow.contrib import tensorboard
from tensorflow.contrib import testing
from tensorflow.contrib import tfprof
from tensorflow.contrib import timeseries
from tensorflow.contrib import tpu
from tensorflow.contrib import training
from tensorflow.contrib import util
from tensorflow.contrib.eager.python import tfe as eager
from tensorflow.contrib.optimizer_v2 import optimizer_v2_symbols as optimizer_v2
from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field
from tensorflow.contrib.recurrent.python import recurrent_api as recurrent
from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph
from tensorflow.contrib.specs import python as specs
from tensorflow.contrib.summary import summary
if os.name != "nt" and platform.machine() != "s390x":
  try:
    from tensorflow.contrib import cloud
  except ImportError:
    pass
from tensorflow.python.util.lazy_loader import LazyLoader
ffmpeg = LazyLoader("ffmpeg", globals(),
"tensorflow.contrib.ffmpeg")
del os
del platform
del LazyLoader
del absolute_import
del division
del print_function


@@ -1,33 +0,0 @@
# Description:
# All-reduce implementations.
# APIs are subject to change and will eventually be replaced by equivalent
# functionality within TensorFlow core.
package(
default_visibility = ["//tensorflow:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
exports_files(["LICENSE"])
py_library(
name = "all_reduce_py",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
":all_reduce",
"//tensorflow/python:util",
],
)
py_library(
name = "all_reduce",
srcs = [
"python/all_reduce.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/python/distribute:all_reduce",
],
)


@@ -1,39 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""All-reduce implementations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
from tensorflow.contrib.all_reduce.python.all_reduce import *
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
_allowed_symbols = [
'build_ring_all_reduce',
'build_recursive_hd_all_reduce',
'build_shuffle_all_reduce',
'build_nccl_all_reduce',
'build_nccl_then_ring',
'build_nccl_then_recursive_hd',
'build_nccl_then_shuffle',
'build_shuffle_then_ring',
'build_shuffle_then_shuffle'
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)


@@ -1,22 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to construct a TF subgraph implementing distributed All-Reduce."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
from tensorflow.python.distribute.all_reduce import *


@@ -1,87 +0,0 @@
# Description:
# JNI-based Java inference interface for TensorFlow.
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
load(
"//tensorflow:tensorflow.bzl",
"if_android",
"tf_copts",
)
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
exports_files([
"LICENSE",
"jni/version_script.lds",
])
filegroup(
name = "android_tensorflow_inference_jni_srcs",
srcs = glob([
"**/*.cc",
"**/*.h",
]),
visibility = ["//visibility:public"],
)
cc_library(
name = "android_tensorflow_inference_jni",
srcs = if_android([":android_tensorflow_inference_jni_srcs"]),
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/java/src/main/native",
],
alwayslink = 1,
)
# JAR with Java bindings to TF.
android_library(
name = "android_tensorflow_inference_java",
srcs = glob(["java/**/*.java"]) + ["//tensorflow/java:java_sources"],
tags = [
"manual",
"notap",
],
)
# Build the native .so.
# bazel build //tensorflow/contrib/android:libtensorflow_inference.so \
# --crosstool_top=//external:android/crosstool \
# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
# --cpu=armeabi-v7a
LINKER_SCRIPT = "//tensorflow/contrib/android:jni/version_script.lds"
cc_binary(
name = "libtensorflow_inference.so",
copts = tf_copts() + [
"-ffunction-sections",
"-fdata-sections",
],
linkopts = if_android([
"-landroid",
"-latomic",
"-ldl",
"-llog",
"-lm",
"-z defs",
"-s",
"-Wl,--gc-sections",
"-Wl,--version-script,$(location {})".format(LINKER_SCRIPT),
]),
linkshared = 1,
linkstatic = 1,
tags = [
"manual",
"notap",
],
deps = [
":android_tensorflow_inference_jni",
"//tensorflow/core:android_tensorflow_lib",
LINKER_SCRIPT,
],
)


@@ -1,95 +0,0 @@
# Android TensorFlow support
This directory defines components (a native `.so` library and a Java JAR)
geared towards supporting TensorFlow on Android. This includes:
- The [TensorFlow Java API](../../java/README.md)
- A `TensorFlowInferenceInterface` class that provides a smaller API
surface suitable for inference and summarizing performance of model execution.
For example usage, see [TensorFlowImageClassifier.java](../../examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java)
in the [TensorFlow Android Demo](../../examples/android).
For prebuilt libraries, see the
[nightly Android build artifacts](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)
page for a recent build.
The TensorFlow Inference Interface is also available as a
[JCenter package](https://bintray.com/google/tensorflow/tensorflow)
(see the tensorflow-android directory) and can be included quite simply in your
Android project with a couple of lines in the project's `build.gradle` file:
```
allprojects {
repositories {
jcenter()
}
}
dependencies {
compile 'org.tensorflow:tensorflow-android:+'
}
```
This will tell Gradle to use the
[latest version](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
of the TensorFlow AAR that has been released to
[JCenter](https://jcenter.bintray.com/org/tensorflow/tensorflow-android/).
You may replace the `+` with an explicit version label if you wish to
use a specific release of TensorFlow in your app.
To build the libraries yourself (if, for example, you want to support custom
TensorFlow operators), pick your preferred approach below:
### Bazel
First follow the Bazel setup instructions described in
[tensorflow/examples/android/README.md](../../examples/android/README.md)
Then, to build the native TF library:
```
bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \
--crosstool_top=//external:android/crosstool \
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
--cxxopt=-std=c++11 \
--cpu=armeabi-v7a
```
Replace `armeabi-v7a` with your desired target architecture.
The library will be located at:
```
bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so
```
To build the Java counterpart:
```
bazel build //tensorflow/contrib/android:android_tensorflow_inference_java
```
You will find the JAR file at:
```
bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar
```
### CMake
For documentation on building a self-contained AAR file with cmake, see
[tensorflow/contrib/android/cmake](cmake).
### Makefile
For documentation on building native TF libraries with make, including a CUDA-enabled variant for devices like the Nvidia Shield TV, see [tensorflow/contrib/makefile/README.md](../makefile/README.md)
## AssetManagerFileSystem
This directory also contains a TensorFlow filesystem supporting the Android
asset manager. This may be useful when writing native (C++) code that is tightly
coupled with TensorFlow. For typical usage, the library above will be
sufficient.


@@ -1,272 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/android/asset_manager_filesystem.h"
#include <unistd.h>
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system_helper.h"
namespace tensorflow {
namespace {
string RemoveSuffix(const string& name, const string& suffix) {
string output(name);
StringPiece piece(output);
absl::ConsumeSuffix(&piece, suffix);
return string(piece);
}
// Closes the given AAsset when variable is destructed.
class ScopedAsset {
public:
ScopedAsset(AAsset* asset) : asset_(asset) {}
~ScopedAsset() {
if (asset_ != nullptr) {
AAsset_close(asset_);
}
}
AAsset* get() const { return asset_; }
private:
AAsset* asset_;
};
// Closes the given AAssetDir when variable is destructed.
class ScopedAssetDir {
public:
ScopedAssetDir(AAssetDir* asset_dir) : asset_dir_(asset_dir) {}
~ScopedAssetDir() {
if (asset_dir_ != nullptr) {
AAssetDir_close(asset_dir_);
}
}
AAssetDir* get() const { return asset_dir_; }
private:
AAssetDir* asset_dir_;
};
class ReadOnlyMemoryRegionFromAsset : public ReadOnlyMemoryRegion {
public:
ReadOnlyMemoryRegionFromAsset(std::unique_ptr<char[]> data, uint64 length)
: data_(std::move(data)), length_(length) {}
~ReadOnlyMemoryRegionFromAsset() override = default;
const void* data() override { return reinterpret_cast<void*>(data_.get()); }
uint64 length() override { return length_; }
private:
std::unique_ptr<char[]> data_;
uint64 length_;
};
// Note that AAssets are not thread-safe and cannot be used across threads.
// However, AAssetManager is. Because RandomAccessFile must be thread-safe and
// used across threads, new AAssets must be created for every access.
// TODO(tylerrhodes): is there a more efficient way to do this?
class RandomAccessFileFromAsset : public RandomAccessFile {
public:
RandomAccessFileFromAsset(AAssetManager* asset_manager, const string& name)
: asset_manager_(asset_manager), file_name_(name) {}
~RandomAccessFileFromAsset() override = default;
Status Read(uint64 offset, size_t to_read, StringPiece* result,
char* scratch) const override {
auto asset = ScopedAsset(AAssetManager_open(
asset_manager_, file_name_.c_str(), AASSET_MODE_RANDOM));
if (asset.get() == nullptr) {
return errors::NotFound("File ", file_name_, " not found.");
}
off64_t new_offset = AAsset_seek64(asset.get(), offset, SEEK_SET);
off64_t length = AAsset_getLength64(asset.get());
if (new_offset < 0) {
*result = StringPiece(scratch, 0);
return errors::OutOfRange("Read after file end.");
}
const off64_t region_left =
std::min(length - new_offset, static_cast<off64_t>(to_read));
int read = AAsset_read(asset.get(), scratch, region_left);
if (read < 0) {
return errors::Internal("Error reading from asset.");
}
*result = StringPiece(scratch, region_left);
return (region_left == to_read)
? Status::OK()
: errors::OutOfRange("Read less bytes than requested.");
}
private:
AAssetManager* asset_manager_;
string file_name_;
};
} // namespace
AssetManagerFileSystem::AssetManagerFileSystem(AAssetManager* asset_manager,
const string& prefix)
: asset_manager_(asset_manager), prefix_(prefix) {}
Status AssetManagerFileSystem::FileExists(const string& fname) {
string path = RemoveAssetPrefix(fname);
auto asset = ScopedAsset(
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM));
if (asset.get() == nullptr) {
return errors::NotFound("File ", fname, " not found.");
}
return Status::OK();
}
Status AssetManagerFileSystem::NewRandomAccessFile(
const string& fname, std::unique_ptr<RandomAccessFile>* result) {
string path = RemoveAssetPrefix(fname);
auto asset = ScopedAsset(
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM));
if (asset.get() == nullptr) {
return errors::NotFound("File ", fname, " not found.");
}
result->reset(new RandomAccessFileFromAsset(asset_manager_, path));
return Status::OK();
}
Status AssetManagerFileSystem::NewReadOnlyMemoryRegionFromFile(
const string& fname, std::unique_ptr<ReadOnlyMemoryRegion>* result) {
string path = RemoveAssetPrefix(fname);
auto asset = ScopedAsset(
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_STREAMING));
if (asset.get() == nullptr) {
return errors::NotFound("File ", fname, " not found.");
}
off64_t start, length;
int fd = AAsset_openFileDescriptor64(asset.get(), &start, &length);
std::unique_ptr<char[]> data;
if (fd >= 0) {
data.reset(new char[length]);
ssize_t result = pread(fd, data.get(), length, start);
if (result < 0) {
return errors::Internal("Error reading from file ", fname,
" using 'read': ", result);
}
if (result != length) {
return errors::Internal("Expected size does not match size read: ",
"Expected ", length, " vs. read ", result);
}
close(fd);
} else {
length = AAsset_getLength64(asset.get());
data.reset(new char[length]);
const void* asset_buffer = AAsset_getBuffer(asset.get());
if (asset_buffer == nullptr) {
return errors::Internal("Error reading ", fname, " from asset manager.");
}
memcpy(data.get(), asset_buffer, length);
}
result->reset(new ReadOnlyMemoryRegionFromAsset(std::move(data), length));
return Status::OK();
}
Status AssetManagerFileSystem::GetChildren(const string& prefixed_dir,
std::vector<string>* r) {
std::string path = NormalizeDirectoryPath(prefixed_dir);
auto dir =
ScopedAssetDir(AAssetManager_openDir(asset_manager_, path.c_str()));
if (dir.get() == nullptr) {
return errors::NotFound("Directory ", prefixed_dir, " not found.");
}
const char* next_file = AAssetDir_getNextFileName(dir.get());
while (next_file != nullptr) {
r->push_back(next_file);
next_file = AAssetDir_getNextFileName(dir.get());
}
return Status::OK();
}
Status AssetManagerFileSystem::GetFileSize(const string& fname, uint64* s) {
// If fname corresponds to a directory, return early. It doesn't map to an
// AAsset, and would otherwise return NotFound.
if (DirectoryExists(fname)) {
*s = 0;
return Status::OK();
}
string path = RemoveAssetPrefix(fname);
auto asset = ScopedAsset(
AAssetManager_open(asset_manager_, path.c_str(), AASSET_MODE_RANDOM));
if (asset.get() == nullptr) {
return errors::NotFound("File ", fname, " not found.");
}
*s = AAsset_getLength64(asset.get());
return Status::OK();
}
Status AssetManagerFileSystem::Stat(const string& fname, FileStatistics* stat) {
uint64 size;
stat->is_directory = DirectoryExists(fname);
TF_RETURN_IF_ERROR(GetFileSize(fname, &size));
stat->length = size;
return Status::OK();
}
string AssetManagerFileSystem::NormalizeDirectoryPath(const string& fname) {
return RemoveSuffix(RemoveAssetPrefix(fname), "/");
}
string AssetManagerFileSystem::RemoveAssetPrefix(const string& name) {
StringPiece piece(name);
absl::ConsumePrefix(&piece, prefix_);
return string(piece);
}
bool AssetManagerFileSystem::DirectoryExists(const std::string& fname) {
std::string path = NormalizeDirectoryPath(fname);
auto dir =
ScopedAssetDir(AAssetManager_openDir(asset_manager_, path.c_str()));
// Note that openDir will return something even if the directory doesn't
// exist. Therefore, we need to ensure one file exists in the folder.
return AAssetDir_getNextFileName(dir.get()) != NULL;
}
Status AssetManagerFileSystem::GetMatchingPaths(const string& pattern,
std::vector<string>* results) {
return internal::GetMatchingPaths(this, Env::Default(), pattern, results);
}
Status AssetManagerFileSystem::NewWritableFile(
const string& fname, std::unique_ptr<WritableFile>* result) {
return errors::Unimplemented("Asset storage is read only.");
}
Status AssetManagerFileSystem::NewAppendableFile(
const string& fname, std::unique_ptr<WritableFile>* result) {
return errors::Unimplemented("Asset storage is read only.");
}
Status AssetManagerFileSystem::DeleteFile(const string& f) {
return errors::Unimplemented("Asset storage is read only.");
}
Status AssetManagerFileSystem::CreateDir(const string& d) {
return errors::Unimplemented("Asset storage is read only.");
}
Status AssetManagerFileSystem::DeleteDir(const string& d) {
return errors::Unimplemented("Asset storage is read only.");
}
Status AssetManagerFileSystem::RenameFile(const string& s, const string& t) {
return errors::Unimplemented("Asset storage is read only.");
}
} // namespace tensorflow


@@ -1,85 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
#define TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_
#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>
#include "tensorflow/core/platform/file_system.h"
namespace tensorflow {
// FileSystem that uses Android's AAssetManager. Once initialized with a given
// AAssetManager, files in the given AAssetManager can be accessed through the
// prefix given when registered with the TensorFlow Env.
// Note that because APK assets are immutable, any operation that tries to
// modify the FileSystem will return tensorflow::error::code::UNIMPLEMENTED.
class AssetManagerFileSystem : public FileSystem {
public:
// Initialize an AssetManagerFileSystem. Note that this does not register the
// file system with TensorFlow.
// asset_manager - Non-null Android AAssetManager that backs this file
// system. The asset manager is not owned by this file system, and must
// outlive this class.
// prefix - Common prefix to strip from all file URIs before passing them to
// the asset_manager. This is required because TensorFlow gives the entire
// file URI (file:///my_dir/my_file.txt) and AssetManager only knows paths
// relative to its base directory.
AssetManagerFileSystem(AAssetManager* asset_manager, const string& prefix);
~AssetManagerFileSystem() override = default;
Status FileExists(const string& fname) override;
Status NewRandomAccessFile(
const string& filename,
std::unique_ptr<RandomAccessFile>* result) override;
Status NewReadOnlyMemoryRegionFromFile(
const string& filename,
std::unique_ptr<ReadOnlyMemoryRegion>* result) override;
Status GetFileSize(const string& f, uint64* s) override;
// Currently just returns size.
Status Stat(const string& fname, FileStatistics* stat) override;
Status GetChildren(const string& dir, std::vector<string>* r) override;
// All these functions return Unimplemented error. Asset storage is
// read only.
Status NewWritableFile(const string& fname,
std::unique_ptr<WritableFile>* result) override;
Status NewAppendableFile(const string& fname,
std::unique_ptr<WritableFile>* result) override;
Status DeleteFile(const string& f) override;
Status CreateDir(const string& d) override;
Status DeleteDir(const string& d) override;
Status RenameFile(const string& s, const string& t) override;
Status GetMatchingPaths(const string& pattern,
std::vector<string>* results) override;
private:
string RemoveAssetPrefix(const string& name);
// Return a string path that can be passed into AAssetManager functions.
// For example, 'my_prefix://some/dir/' would return 'some/dir'.
string NormalizeDirectoryPath(const string& fname);
bool DirectoryExists(const std::string& fname);
AAssetManager* asset_manager_;
string prefix_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_


@@ -1,80 +0,0 @@
#
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cmake_minimum_required(VERSION 3.4.1)
include(ExternalProject)
# TENSORFLOW_ROOT_DIR:
# root directory of tensorflow repo
# used for shared source files and pre-built libs
get_filename_component(TENSORFLOW_ROOT_DIR ../../../.. ABSOLUTE)
set(PREBUILT_DIR ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen)
add_library(lib_proto STATIC IMPORTED )
set_target_properties(lib_proto PROPERTIES IMPORTED_LOCATION
${PREBUILT_DIR}/protobuf/lib/libprotobuf.a)
add_library(lib_nsync STATIC IMPORTED )
set_target_properties(lib_nsync PROPERTIES IMPORTED_LOCATION
${TARGET_NSYNC_LIB}/lib/libnsync.a)
add_library(lib_tf STATIC IMPORTED )
set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION
${PREBUILT_DIR}/lib/libtensorflow-core.a)
# Changes to compile flags should be replicated in the bazel build file.
# TODO: Consider options other than -O2 for binary size.
# e.g. -Os for gcc, and -Oz for clang.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \
-std=c++11 -fno-rtti -fno-exceptions \
-O2 -Wno-narrowing -fomit-frame-pointer \
-mfpu=neon -mfloat-abi=softfp -fPIE -fPIC \
-ftemplate-depth=900 \
-DGOOGLE_PROTOBUF_NO_RTTI \
-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} \
-Wl,--allow-multiple-definition \
-Wl,--whole-archive \
-fPIE -pie -v")
file(GLOB tensorflow_inference_sources
${CMAKE_CURRENT_SOURCE_DIR}/../jni/*.cc)
file(GLOB java_api_native_sources
${TENSORFLOW_ROOT_DIR}/tensorflow/java/src/main/native/*.cc)
add_library(tensorflow_inference SHARED
${tensorflow_inference_sources}
${TENSORFLOW_ROOT_DIR}/tensorflow/c/tf_status_helper.cc
${TENSORFLOW_ROOT_DIR}/tensorflow/c/checkpoint_reader.cc
${TENSORFLOW_ROOT_DIR}/tensorflow/c/test_op.cc
${TENSORFLOW_ROOT_DIR}/tensorflow/c/c_api.cc
${java_api_native_sources})
# Include libraries needed for hello-jni lib
target_link_libraries(tensorflow_inference
android
dl
log
m
z
lib_tf
lib_proto
lib_nsync)
include_directories(
${PREBUILT_DIR}/proto
${PREBUILT_DIR}/protobuf/include
${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/downloads/eigen
${TENSORFLOW_ROOT_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/..)


@@ -1,48 +0,0 @@
TensorFlow-Android-Inference
============================
This directory contains CMake support for building the Android Java Inference
interface to the TensorFlow native APIs.
See [tensorflow/contrib/android](..) for more details about the library, and
instructions for building with Bazel.
Usage
-----
Add TensorFlow-Android-Inference as a dependency of your Android application
* settings.gradle
```
include ':TensorFlow-Android-Inference'
findProject(":TensorFlow-Android-Inference").projectDir =
new File("${/path/to/tensorflow_repo}/contrib/android/cmake")
```
* application's build.gradle (adding dependency):
```
debugCompile project(path: ':tensorflow_inference', configuration: 'debug')
releaseCompile project(path: ':tensorflow_inference', configuration: 'release')
```
Note: this makes native code in the lib traceable from your app.
Dependencies
------------
TensorFlow-Android-Inference depends on the TensorFlow static libs already built
in your local TensorFlow repo directory. On Linux/macOS, build.gradle invokes
build_all_android.sh to build them. Building the core libs takes considerable
time, so this step is commented out by default to avoid confusion (otherwise
Android Studio would appear to hang while opening the project).
To enable it, refer to the comment in
* build.gradle
Output
------
- TensorFlow-Inference-debug.aar
- TensorFlow-Inference-release.aar
The file libtensorflow_inference.so is packed under jni/${ANDROID_ABI}/ in the
AARs above. This is transparent to the app, which accesses the native code
through the equivalent Java APIs.


@@ -1,105 +0,0 @@
apply plugin: 'com.android.library'
// TensorFlow repo root dir on local machine
def TF_SRC_DIR = projectDir.toString() + "/../../../.."
android {
compileSdkVersion 24
// Check local build_tools_version as this is liable to change within Android Studio.
buildToolsVersion '25.0.2'
// for debugging native code
publishNonDefault true
defaultConfig {
archivesBaseName = "Tensorflow-Android-Inference"
minSdkVersion 21
targetSdkVersion 23
versionCode 1
versionName "1.0"
ndk {
abiFilters 'armeabi-v7a'
}
externalNativeBuild {
cmake {
arguments '-DANDROID_TOOLCHAIN=clang',
'-DANDROID_STL=c++_static'
}
}
}
sourceSets {
main {
java {
srcDir "${TF_SRC_DIR}/tensorflow/contrib/android/java"
srcDir "${TF_SRC_DIR}/tensorflow/java/src/main/java"
exclude '**/examples/**'
}
}
}
externalNativeBuild {
cmake {
path 'CMakeLists.txt'
}
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android.txt'),
'proguard-rules.pro'
}
}
}
// Build libtensorflow-core.a if necessary
// Note: the environment needs to be set up already
// [ such as installing autoconf, make, etc ]
// How to use:
// 1) install all of the necessary tools to build libtensorflow-core.a
// 2) inside Android Studio IDE, uncomment buildTensorFlow in
// whenTaskAdded{...}
// 3) re-sync and re-build. It could take a long time if NOT building
// with multiple processes.
import org.apache.tools.ant.taskdefs.condition.Os
Properties properties = new Properties()
properties.load(project.rootProject.file('local.properties')
.newDataInputStream())
def ndkDir = properties.getProperty('ndk.dir')
if (ndkDir == null || ndkDir == "") {
ndkDir = System.getenv('ANDROID_NDK_HOME')
}
if (!Os.isFamily(Os.FAMILY_WINDOWS)) {
// This script is for non-Windows OS. For Windows OS, MANUALLY build
// (or copy the built) libs/headers to the
// ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/gen
// refer to CMakeLists.txt about lib and header directories for details
task buildTensorflow(type: Exec) {
group 'buildTensorflowLib'
workingDir getProjectDir().toString() + '/../../../../'
environment PATH: '/opt/local/bin:/opt/local/sbin:' +
System.getenv('PATH')
environment NDK_ROOT: ndkDir
commandLine 'tensorflow/contrib/makefile/build_all_android.sh'
}
tasks.whenTaskAdded { task ->
group 'buildTensorflowLib'
if (task.name.toLowerCase().contains('sources')) {
def tensorflowTarget = new File(getProjectDir().toString() +
'/../../makefile/gen/lib/libtensorflow-core.a')
if (!tensorflowTarget.exists()) {
// Note: uncomment the following line to enable the build. It is
// disabled by default because building can take a long time, which
// would give a false first impression of the project.
// task.dependsOn buildTensorflow
}
}
}
}
dependencies {
compile fileTree(dir: 'libs', include: ['*.jar'])
}


@@ -1,13 +0,0 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.tensorflow.contrib.android">
<uses-sdk
android:minSdkVersion="4"
android:targetSdkVersion="19" />
<application android:allowBackup="true" android:label="@string/app_name"
android:supportsRtl="true">
</application>
</manifest>


@@ -1,3 +0,0 @@
<resources>
<string name="app_name">TensorFlowInference</string>
</resources>


@@ -1,63 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.contrib.android;
/** Accumulate and analyze stats from metadata obtained from Session.Runner.run. */
public class RunStats implements AutoCloseable {
/**
* Options to be provided to a {@link org.tensorflow.Session.Runner} to enable stats accumulation.
*/
public static byte[] runOptions() {
return fullTraceRunOptions;
}
public RunStats() {
nativeHandle = allocate();
}
@Override
public void close() {
if (nativeHandle != 0) {
delete(nativeHandle);
}
nativeHandle = 0;
}
/** Accumulate stats obtained when executing a graph. */
public synchronized void add(byte[] runMetadata) {
add(nativeHandle, runMetadata);
}
/** Summary of the accumulated runtime stats. */
public synchronized String summary() {
return summary(nativeHandle);
}
private long nativeHandle;
// Hack: This is what a serialized RunOptions protocol buffer with trace_level: FULL_TRACE ends
// up as.
private static byte[] fullTraceRunOptions = new byte[] {0x08, 0x03};
private static native long allocate();
private static native void delete(long handle);
private static native void add(long handle, byte[] runMetadata);
private static native String summary(long handle);
}
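For context, the snippet below sketches how `RunStats` fits together with a raw
`org.tensorflow.Session`, using only calls that appear in this commit
(`runOptions()`, `setOptions`, `runAndFetchMetadata`, `add`, `summary`). The
`session` variable, the `"output"` node name, and the usual imports
(`org.tensorflow.Session`, `org.tensorflow.Tensor`, `android.util.Log`) are
assumptions for illustration:

```java
// A minimal sketch, assuming an existing org.tensorflow.Session `session`
// whose graph has a node named "output" (hypothetical).
try (RunStats stats = new RunStats()) {
  Session.Run result =
      session.runner()
          .fetch("output")
          .setOptions(RunStats.runOptions()) // RunOptions with trace_level: FULL_TRACE
          .runAndFetchMetadata();
  stats.add(result.metadata); // accumulate the serialized RunMetadata
  Log.v("RunStatsExample", stats.summary());
  for (Tensor<?> t : result.outputs) {
    t.close(); // release the fetched tensors
  }
}
```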


@@ -1,650 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.contrib.android;
import android.content.res.AssetManager;
import android.os.Build.VERSION;
import android.os.Trace;
import android.text.TextUtils;
import android.util.Log;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.Tensors;
import org.tensorflow.types.UInt8;
/**
* Wrapper over the TensorFlow API ({@link Graph}, {@link Session}) providing a smaller API surface
* for inference.
*
* <p>See tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java for an
* example usage.
*/
public class TensorFlowInferenceInterface {
private static final String TAG = "TensorFlowInferenceInterface";
private static final String ASSET_FILE_PREFIX = "file:///android_asset/";
/*
* Load a TensorFlow model from the AssetManager or from disk if it is not an asset file.
*
* @param assetManager The AssetManager to use to load the model file.
* @param model The filepath to the GraphDef proto representing the model.
*/
public TensorFlowInferenceInterface(AssetManager assetManager, String model) {
prepareNativeRuntime();
this.modelName = model;
this.g = new Graph();
this.sess = new Session(g);
this.runner = sess.runner();
final boolean hasAssetPrefix = model.startsWith(ASSET_FILE_PREFIX);
InputStream is = null;
try {
String aname = hasAssetPrefix ? model.split(ASSET_FILE_PREFIX)[1] : model;
is = assetManager.open(aname);
} catch (IOException e) {
if (hasAssetPrefix) {
throw new RuntimeException("Failed to load model from '" + model + "'", e);
}
// Perhaps the model file is not an asset but is on disk.
try {
is = new FileInputStream(model);
} catch (IOException e2) {
throw new RuntimeException("Failed to load model from '" + model + "'", e);
}
}
try {
if (VERSION.SDK_INT >= 18) {
Trace.beginSection("initializeTensorFlow");
Trace.beginSection("readGraphDef");
}
// TODO(ashankar): Can we somehow mmap the contents instead of copying them?
byte[] graphDef = new byte[is.available()];
final int numBytesRead = is.read(graphDef);
if (numBytesRead != graphDef.length) {
throw new IOException(
"read error: read only "
+ numBytesRead
+ " of the graph, expected to read "
+ graphDef.length);
}
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // readGraphDef.
}
loadGraph(graphDef, g);
is.close();
Log.i(TAG, "Successfully loaded model from '" + model + "'");
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // initializeTensorFlow.
}
} catch (IOException e) {
throw new RuntimeException("Failed to load model from '" + model + "'", e);
}
}
/*
* Load a TensorFlow model from provided InputStream.
* Note: The InputStream will not be closed after the model is loaded; callers
* need to close it themselves.
*
* @param is The InputStream to use to load the model.
*/
public TensorFlowInferenceInterface(InputStream is) {
prepareNativeRuntime();
// modelName is not meaningful when loading from an InputStream; it is
// initialized here only because the field is final.
this.modelName = "";
this.g = new Graph();
this.sess = new Session(g);
this.runner = sess.runner();
try {
if (VERSION.SDK_INT >= 18) {
Trace.beginSection("initializeTensorFlow");
Trace.beginSection("readGraphDef");
}
int baosInitSize = is.available() > 16384 ? is.available() : 16384;
ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize);
int numBytesRead;
byte[] buf = new byte[16384];
while ((numBytesRead = is.read(buf, 0, buf.length)) != -1) {
baos.write(buf, 0, numBytesRead);
}
byte[] graphDef = baos.toByteArray();
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // readGraphDef.
}
loadGraph(graphDef, g);
Log.i(TAG, "Successfully loaded model from the input stream");
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // initializeTensorFlow.
}
} catch (IOException e) {
throw new RuntimeException("Failed to load model from the input stream", e);
}
}
/*
* Construct a TensorFlowInferenceInterface with provided Graph
*
* @param g The Graph to use to construct this interface.
*/
public TensorFlowInferenceInterface(Graph g) {
prepareNativeRuntime();
// modelName is not meaningful here; it is initialized only because the
// field is final.
this.modelName = "";
this.g = g;
this.sess = new Session(g);
this.runner = sess.runner();
}
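// Typical end-to-end usage of this class, shown as a sketch in comments so the
// surrounding source stays intact. The node names "input"/"output", the model
// path, and the float arrays are hypothetical; getAssets() assumes an Android
// Context. Each call below matches a method defined in this class:
//
//   TensorFlowInferenceInterface tfi =
//       new TensorFlowInferenceInterface(getAssets(), "file:///android_asset/model.pb");
//   tfi.feed("input", pixelValues, 1, inputHeight, inputWidth, 3);
//   tfi.run(new String[] {"output"});
//   tfi.fetch("output", outputScores);
//   tfi.close();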
/**
* Runs inference between the previously registered input nodes (via feed*) and the requested
* output nodes. Output nodes can then be queried with the fetch* methods.
*
* @param outputNames A list of output nodes which should be filled by the inference pass.
*/
public void run(String[] outputNames) {
run(outputNames, false);
}
/**
* Runs inference between the previously registered input nodes (via feed*) and the requested
* output nodes. Output nodes can then be queried with the fetch* methods.
*
* @param outputNames A list of output nodes which should be filled by the inference pass.
* @param enableStats Whether to accumulate execution stats for the run, retrievable via {@link
*     #getStatString()}.
*/
public void run(String[] outputNames, boolean enableStats) {
run(outputNames, enableStats, new String[] {});
}
/** An overloaded version of runInference that allows supplying targetNodeNames as well */
public void run(String[] outputNames, boolean enableStats, String[] targetNodeNames) {
// Release any Tensors from the previous run calls.
closeFetches();
// Add fetches.
for (String o : outputNames) {
fetchNames.add(o);
TensorId tid = TensorId.parse(o);
runner.fetch(tid.name, tid.outputIndex);
}
// Add targets.
for (String t : targetNodeNames) {
runner.addTarget(t);
}
// Run the session.
try {
if (enableStats) {
Session.Run r = runner.setOptions(RunStats.runOptions()).runAndFetchMetadata();
fetchTensors = r.outputs;
if (runStats == null) {
runStats = new RunStats();
}
runStats.add(r.metadata);
} else {
fetchTensors = runner.run();
}
} catch (RuntimeException e) {
// Ideally the exception would have been let through, but since this interface predates the
// TensorFlow Java API, must return -1.
Log.e(
TAG,
"Failed to run TensorFlow inference with inputs:["
+ TextUtils.join(", ", feedNames)
+ "], outputs:["
+ TextUtils.join(", ", fetchNames)
+ "]");
throw e;
} finally {
// Always release the feeds (to save resources) and reset the runner; this run
// is over.
closeFeeds();
runner = sess.runner();
}
}
/** Returns a reference to the Graph describing the computation run during inference. */
public Graph graph() {
return g;
}
public Operation graphOperation(String operationName) {
final Operation operation = g.operation(operationName);
if (operation == null) {
throw new RuntimeException(
"Node '" + operationName + "' does not exist in model '" + modelName + "'");
}
return operation;
}
/** Returns the last stat summary string if logging is enabled. */
public String getStatString() {
return (runStats == null) ? "" : runStats.summary();
}
/**
* Cleans up the state associated with this Object.
*
* <p>The TensorFlowInferenceInterface object is no longer usable after this method returns.
*/
public void close() {
closeFeeds();
closeFetches();
sess.close();
g.close();
if (runStats != null) {
runStats.close();
}
runStats = null;
}
@Override
protected void finalize() throws Throwable {
try {
close();
} finally {
super.finalize();
}
}
// Methods for taking a native Tensor and filling it with values from Java arrays.
/**
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, boolean[] src, long... dims) {
byte[] b = new byte[src.length];
for (int i = 0; i < src.length; i++) {
b[i] = src[i] ? (byte) 1 : (byte) 0;
}
addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b)));
}
/**
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, float[] src, long... dims) {
addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src)));
}
/**
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, int[] src, long... dims) {
addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src)));
}
/**
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, long[] src, long... dims) {
addFeed(inputName, Tensor.create(dims, LongBuffer.wrap(src)));
}
/**
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, double[] src, long... dims) {
addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src)));
}
/**
* Given a source array with shape {@link dims} and content {@link src}, copy the contents into
* the input Tensor with name {@link inputName}. The source array {@link src} must have at least
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, byte[] src, long... dims) {
addFeed(inputName, Tensor.create(UInt8.class, dims, ByteBuffer.wrap(src)));
}
/**
* Copy a byte sequence into the input Tensor with name {@link inputName} as a string-valued
* scalar tensor. In the TensorFlow type system, a "string" is an arbitrary sequence of bytes, not
* a Java {@code String} (which is a sequence of characters).
*/
public void feedString(String inputName, byte[] src) {
addFeed(inputName, Tensors.create(src));
}
/**
* Copy an array of byte sequences into the input Tensor with name {@link inputName} as a
* string-valued one-dimensional tensor (vector). In the TensorFlow type system, a "string" is an
* arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters).
*/
public void feedString(String inputName, byte[][] src) {
addFeed(inputName, Tensors.create(src));
}
// Methods for taking a native Tensor and filling it with src from Java native IO buffers.
/**
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, FloatBuffer src, long... dims) {
addFeed(inputName, Tensor.create(dims, src));
}
/**
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, IntBuffer src, long... dims) {
addFeed(inputName, Tensor.create(dims, src));
}
/**
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, LongBuffer src, long... dims) {
addFeed(inputName, Tensor.create(dims, src));
}
/**
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, DoubleBuffer src, long... dims) {
addFeed(inputName, Tensor.create(dims, src));
}
/**
* Given a source buffer with shape {@link dims} and content {@link src}, both stored as
* <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input
* Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, ByteBuffer src, long... dims) {
addFeed(inputName, Tensor.create(UInt8.class, dims, src));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
*/
public void fetch(String outputName, float[] dst) {
fetch(outputName, FloatBuffer.wrap(dst));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
*/
public void fetch(String outputName, int[] dst) {
fetch(outputName, IntBuffer.wrap(dst));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
*/
public void fetch(String outputName, long[] dst) {
fetch(outputName, LongBuffer.wrap(dst));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
*/
public void fetch(String outputName, double[] dst) {
fetch(outputName, DoubleBuffer.wrap(dst));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
*/
public void fetch(String outputName, byte[] dst) {
fetch(outputName, ByteBuffer.wrap(dst));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
*/
public void fetch(String outputName, FloatBuffer dst) {
getTensor(outputName).writeTo(dst);
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
*/
public void fetch(String outputName, IntBuffer dst) {
getTensor(outputName).writeTo(dst);
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
*/
public void fetch(String outputName, LongBuffer dst) {
getTensor(outputName).writeTo(dst);
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
*/
public void fetch(String outputName, DoubleBuffer dst) {
getTensor(outputName).writeTo(dst);
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
*/
public void fetch(String outputName, ByteBuffer dst) {
getTensor(outputName).writeTo(dst);
}
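  // Illustrative usage sketch (not part of the original file). Assuming a graph
  // with a double input node "input" of shape [1, 4], a double output node
  // "output", and this class's run(String[]) method (not shown in this
  // excerpt), a feed/run/fetch round trip looks like:
  //
  //   double[] in = new double[] {1, 2, 3, 4};
  //   double[] out = new double[2];
  //   inference.feed("input", in, 1, 4);      // copy Java data into the input tensor
  //   inference.run(new String[] {"output"}); // execute the graph
  //   inference.fetch("output", out);         // copy the result back into the array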
private void prepareNativeRuntime() {
Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded");
try {
// Hack to see if the native libraries have been loaded.
new RunStats();
Log.i(TAG, "TensorFlow native methods already loaded");
} catch (UnsatisfiedLinkError e1) {
Log.i(
TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference");
try {
System.loadLibrary("tensorflow_inference");
Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
} catch (UnsatisfiedLinkError e2) {
throw new RuntimeException(
"Native TF methods not found; check that the correct native"
+ " libraries are present in the APK.");
}
}
}
private void loadGraph(byte[] graphDef, Graph g) throws IOException {
final long startMs = System.currentTimeMillis();
if (VERSION.SDK_INT >= 18) {
Trace.beginSection("importGraphDef");
}
try {
g.importGraphDef(graphDef);
} catch (IllegalArgumentException e) {
throw new IOException("Not a valid TensorFlow Graph serialization: " + e.getMessage());
}
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // importGraphDef.
}
final long endMs = System.currentTimeMillis();
Log.i(
TAG,
"Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version());
}
private void addFeed(String inputName, Tensor<?> t) {
// The string format accepted by TensorFlowInferenceInterface is node_name[:output_index].
TensorId tid = TensorId.parse(inputName);
runner.feed(tid.name, tid.outputIndex, t);
feedNames.add(inputName);
feedTensors.add(t);
}
private static class TensorId {
String name;
int outputIndex;
// Parse output names into a TensorId.
//
// E.g., "foo" --> ("foo", 0), while "foo:1" --> ("foo", 1)
public static TensorId parse(String name) {
TensorId tid = new TensorId();
int colonIndex = name.lastIndexOf(':');
if (colonIndex < 0) {
tid.outputIndex = 0;
tid.name = name;
return tid;
}
try {
tid.outputIndex = Integer.parseInt(name.substring(colonIndex + 1));
tid.name = name.substring(0, colonIndex);
} catch (NumberFormatException e) {
tid.outputIndex = 0;
tid.name = name;
}
return tid;
}
}
private Tensor<?> getTensor(String outputName) {
int i = 0;
for (String n : fetchNames) {
if (n.equals(outputName)) {
return fetchTensors.get(i);
}
++i;
}
throw new RuntimeException(
"Node '" + outputName + "' was not provided to run(), so it cannot be read");
}
private void closeFeeds() {
for (Tensor<?> t : feedTensors) {
t.close();
}
feedTensors.clear();
feedNames.clear();
}
private void closeFetches() {
for (Tensor<?> t : fetchTensors) {
t.close();
}
fetchTensors.clear();
fetchNames.clear();
}
// Immutable state.
private final String modelName;
private final Graph g;
private final Session sess;
// State reset on every call to run.
private Session.Runner runner;
private List<String> feedNames = new ArrayList<String>();
private List<Tensor<?>> feedTensors = new ArrayList<Tensor<?>>();
private List<String> fetchNames = new ArrayList<String>();
private List<Tensor<?>> fetchTensors = new ArrayList<Tensor<?>>();
// Mutable state.
private RunStats runStats;
}

View File

@ -1,83 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/android/jni/run_stats_jni.h"
#include <jni.h>
#include <sstream>
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/stat_summarizer.h"
using tensorflow::RunMetadata;
using tensorflow::StatSummarizer;
namespace {
StatSummarizer* requireHandle(JNIEnv* env, jlong handle) {
if (handle == 0) {
env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
"close() has been called on the RunStats object");
return nullptr;
}
return reinterpret_cast<StatSummarizer*>(handle);
}
} // namespace
#define RUN_STATS_METHOD(name) \
JNICALL Java_org_tensorflow_contrib_android_RunStats_##name
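// Illustrative note (not part of the original file): RUN_STATS_METHOD(summary)
// expands to
//   JNICALL Java_org_tensorflow_contrib_android_RunStats_summary
// which, by JNI naming convention, is the symbol bound to the native method
// summary() of the Java class org.tensorflow.contrib.android.RunStats.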
JNIEXPORT jlong RUN_STATS_METHOD(allocate)(JNIEnv* env, jclass clazz) {
static_assert(sizeof(jlong) >= sizeof(StatSummarizer*),
"Cannot package C++ object pointers as a Java long");
tensorflow::StatSummarizerOptions opts;
return reinterpret_cast<jlong>(new StatSummarizer(opts));
}
JNIEXPORT void RUN_STATS_METHOD(delete)(JNIEnv* env, jclass clazz,
jlong handle) {
if (handle == 0) return;
delete reinterpret_cast<StatSummarizer*>(handle);
}
JNIEXPORT void RUN_STATS_METHOD(add)(JNIEnv* env, jclass clazz, jlong handle,
jbyteArray run_metadata) {
StatSummarizer* s = requireHandle(env, handle);
if (s == nullptr) return;
jbyte* data = env->GetByteArrayElements(run_metadata, nullptr);
int size = static_cast<int>(env->GetArrayLength(run_metadata));
tensorflow::RunMetadata proto;
if (!proto.ParseFromArray(data, size)) {
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"),
"runMetadata does not seem to be a serialized RunMetadata "
"protocol message");
} else if (proto.has_step_stats()) {
s->ProcessStepStats(proto.step_stats());
}
env->ReleaseByteArrayElements(run_metadata, data, JNI_ABORT);
}
JNIEXPORT jstring RUN_STATS_METHOD(summary)(JNIEnv* env, jclass clazz,
jlong handle) {
StatSummarizer* s = requireHandle(env, handle);
if (s == nullptr) return nullptr;
std::stringstream ret;
ret << s->GetStatsByMetric("Top 10 CPU", tensorflow::StatsCalculator::BY_TIME,
10)
<< s->GetStatsByNodeType() << s->ShortSummary();
return env->NewStringUTF(ret.str().c_str());
}
#undef RUN_STATS_METHOD

View File

@ -1,40 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_
#define ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_
#include <jni.h>
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
#define RUN_STATS_METHOD(name) \
Java_org_tensorflow_contrib_android_RunStats_##name
JNIEXPORT JNICALL jlong RUN_STATS_METHOD(allocate)(JNIEnv*, jclass);
JNIEXPORT JNICALL void RUN_STATS_METHOD(delete)(JNIEnv*, jclass, jlong);
JNIEXPORT JNICALL void RUN_STATS_METHOD(add)(JNIEnv*, jclass, jlong,
jbyteArray);
JNIEXPORT JNICALL jstring RUN_STATS_METHOD(summary)(JNIEnv*, jclass, jlong);
#undef RUN_STATS_METHOD
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
#endif // ORG_TENSORFLOW_JNI_RUN_STATS_JNI_H_

View File

@ -1,11 +0,0 @@
VERS_1.0 {
# Export JNI symbols.
global:
Java_*;
JNI_OnLoad;
JNI_OnUnload;
# Hide everything else.
local:
*;
};
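# Illustrative usage (an assumption, not part of this file): a version script
# like this is applied at link time, e.g. with GNU ld:
#   -Wl,--version-script=path/to/this_script.lds
# so that only the JNI entry points above stay visible in the shared library.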

View File

@ -1,29 +0,0 @@
package(
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
py_library(
name = "autograph",
srcs = [
"__init__.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
# This module is kept for backward compatibility only. To depend on AutoGraph,
# use //third_party/tensorflow/python/autograph instead.
deps = [
"//tensorflow/python/autograph",
],
)

View File

@ -1,9 +0,0 @@
# AutoGraph
**NOTE: As tensorflow.contrib is being
[deprecated](https://github.com/tensorflow/community/pull/18), AutoGraph is
moving into TensorFlow core.
The new code location is `tensorflow/python/autograph`. Please refer to the
README.md file in that directory.**

View File

@ -1,24 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This is the legacy module for AutoGraph, kept for backward compatibility.
New users should instead use `tensorflow.python.autograph`.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph import * # pylint:disable=wildcard-import
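# Illustrative migration sketch (an assumption, not part of the original file):
#
#   # old, via tensorflow.contrib:
#   from tensorflow.contrib import autograph
#   # new, importing the core module directly:
#   from tensorflow.python import autograph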

View File

@ -1,39 +0,0 @@
load("//tensorflow:tensorflow.bzl", "py_test")
load("//tensorflow/tools/test:performance.bzl", "tf_py_logged_benchmark")
package(
licenses = ["notice"], # Apache 2.0
)
py_library(
name = "benchmark_base",
srcs = [
"benchmark_base.py",
],
deps = [
"//tensorflow:tensorflow_py",
],
)
py_test(
name = "cartpole_benchmark",
size = "enormous",
srcs = ["cartpole_benchmark.py"],
python_version = "PY2",
tags = [
"local",
"manual",
"no_oss",
"notap",
"nozapfhahn",
],
deps = [
":benchmark_base",
# Note: required gym dependency may need to be added here.
],
)
tf_py_logged_benchmark(
name = "cartpole_logged_benchmark",
target = "//tensorflow/contrib/autograph/examples/benchmarks:cartpole_benchmark",
)

View File

@ -1,62 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common benchmarking code.
See https://www.tensorflow.org/community/benchmarks for usage.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
import tensorflow as tf
class ReportingBenchmark(tf.test.Benchmark):
"""Base class for a benchmark that reports general performance metrics.
  Subclasses only need to call time_execution, which times a target callable
  and reports the results via report_benchmark.
"""
def time_execution(self, name, target, iters, warm_up_iters=5):
for _ in range(warm_up_iters):
target()
all_times = []
for _ in range(iters):
iter_time = time.time()
target()
all_times.append(time.time() - iter_time)
avg_time = np.average(all_times)
extras = {}
extras['all_times'] = all_times
if isinstance(name, tuple):
extras['name'] = name
name = '_'.join(str(piece) for piece in name)
self.report_benchmark(
iters=iters, wall_time=avg_time, name=name, extras=extras)
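# Illustrative subclass (an assumption, not part of the original file), showing
# the intended use of time_execution with a trivial target callable:
class _ExampleNoopBenchmark(ReportingBenchmark):

  def benchmark_noop(self):
    # Times 100 iterations of a no-op callable, after the default 5 warm-ups.
    self.time_execution('noop', lambda: None, iters=100)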
if __name__ == '__main__':
tf.test.main()

View File

@ -1,492 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A basic RL cartpole benchmark.
The RL model uses the OpenAI Gym environment to train a simple network using
the policy gradients method. The training scales the gradients for each step
by the episode's cumulative discounted reward and averages these gradients over
a fixed number of games before applying the optimization step.
For benchmarking purposes, we replace the OpenAI Gym environment with a fake
one that returns random observations and a fixed reward and never ends the
episode. This way the benchmarks compare the same amount of computation at
each step.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gym
import numpy as np
import tensorflow as tf
from tensorflow.contrib import eager
from tensorflow.contrib.autograph.examples.benchmarks import benchmark_base
from tensorflow.python import autograph as ag
from tensorflow.python.eager import context
#
# AutoGraph implementation
#
@ag.convert()
def graph_append_discounted_rewards(destination, rewards, discount_rate):
"""Discounts episode rewards and appends them to destination."""
ag.set_element_type(rewards, tf.float32)
cdr = 0.0
reverse_discounted = []
ag.set_element_type(reverse_discounted, tf.float32)
for i in range(len(rewards) - 1, -1, -1):
cdr = cdr * discount_rate + rewards[i]
cdr.set_shape(())
reverse_discounted.append(cdr)
retval = destination
# Note: AutoGraph doesn't yet support reversed() so we use a loop instead.
for i in range(len(reverse_discounted) - 1, -1, -1):
retval.append(reverse_discounted[i])
return retval
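# Worked example (illustrative): with rewards [1.0, 1.0, 1.0] and
# discount_rate 0.5, the values appended to destination are [1.75, 1.5, 1.0].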
class GraphPolicyNetwork(tf.keras.Model):
"""Policy network for the cart-pole reinforcement learning problem.
The forward path of the network takes an observation from the cart-pole
environment (length-4 vector) and outputs an action.
"""
def __init__(self, hidden_size):
super(GraphPolicyNetwork, self).__init__()
self._hidden_layer = tf.keras.layers.Dense(
hidden_size, activation=tf.nn.elu)
self._output_layer = tf.keras.layers.Dense(1)
def call(self, inputs):
"""Calculates logits and action.
Args:
inputs: Observations from a step in the cart-pole environment, of shape
`(batch_size, input_size)`
Returns:
logits: the logits output by the output layer. This can be viewed as the
        likelihood values of choosing the left (0) action. Shape:
`(batch_size, 1)`.
actions: randomly selected actions ({0, 1}) based on the logits. Shape:
`(batch_size, 1)`.
"""
hidden = self._hidden_layer(inputs)
logits = self._output_layer(hidden)
left_prob = tf.nn.sigmoid(logits)
action_probs = tf.concat([left_prob, 1.0 - left_prob], 1)
actions = tf.multinomial(tf.log(action_probs), 1)
return logits, actions
# TODO(mdan): Move this method out of the class.
@ag.convert()
def train(self, cart_pole_env, optimizer, discount_rate, num_games,
max_steps_per_game):
var_list = tf.trainable_variables()
grad_list = [
tf.TensorArray(tf.float32, 0, dynamic_size=True) for _ in var_list
]
step_counts = []
discounted_rewards = []
ag.set_element_type(discounted_rewards, tf.float32)
ag.set_element_type(step_counts, tf.int32)
# Note: we use a shared object, cart_pole_env here. Because calls to the
# object's method are made through py_func, TensorFlow cannot detect its
# data dependencies. Hence we must manually synchronize access to it
# and ensure the control dependencies are set in such a way that
# calls to reset(), take_one_step, etc. are made in the correct order.
sync_counter = tf.constant(0)
for _ in tf.range(num_games):
with tf.control_dependencies([sync_counter]):
obs = cart_pole_env.reset()
with tf.control_dependencies([obs]):
sync_counter += 1
game_rewards = []
ag.set_element_type(game_rewards, tf.float32)
for step in tf.range(max_steps_per_game):
logits, actions = self(obs) # pylint:disable=not-callable
logits = tf.reshape(logits, ())
actions = tf.reshape(actions, ())
labels = 1.0 - tf.cast(actions, tf.float32)
loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=labels, logits=logits)
grads = tf.gradients(loss, var_list)
for i in range(len(grads)):
grad_list[i].append(grads[i])
with tf.control_dependencies([sync_counter]):
obs, reward, done = cart_pole_env.step(actions)
with tf.control_dependencies([obs]):
sync_counter += 1
obs = tf.reshape(obs, (1, 4))
game_rewards.append(reward)
if reward < 0.1 or done:
step_counts.append(step + 1)
break
discounted_rewards = graph_append_discounted_rewards(
discounted_rewards, game_rewards, discount_rate)
discounted_rewards = ag.stack(discounted_rewards)
discounted_rewards.set_shape((None,))
mean, variance = tf.nn.moments(discounted_rewards, [0])
normalized_rewards = (discounted_rewards - mean) / tf.sqrt(variance)
for i in range(len(grad_list)):
g = ag.stack(grad_list[i])
# This block just adjusts the shapes to match for multiplication.
r = normalized_rewards
if r.shape.ndims < g.shape.ndims:
r = tf.expand_dims(r, -1)
if r.shape.ndims < g.shape.ndims:
r = tf.expand_dims(r, -1)
grad_list[i] = tf.reduce_mean(g * r, axis=0)
optimizer.apply_gradients(
zip(grad_list, var_list), global_step=tf.train.get_global_step())
return ag.stack(step_counts)
@ag.convert()
def graph_train_model(policy_network, cart_pole_env, optimizer, iterations):
"""Trains the policy network for a given number of iterations."""
i = tf.constant(0)
mean_steps_per_iteration = []
ag.set_element_type(mean_steps_per_iteration, tf.int32)
while i < iterations:
steps_per_game = policy_network.train(
cart_pole_env,
optimizer,
discount_rate=0.95,
num_games=20,
max_steps_per_game=200)
mean_steps_per_iteration.append(tf.reduce_mean(steps_per_game))
i += 1
return ag.stack(mean_steps_per_iteration)
class GraphGymCartpoleEnv(object):
"""An env backed by OpenAI Gym's CartPole environment.
  Used only to confirm that the model is functional.
"""
def __init__(self):
cart_pole_env = gym.make('CartPole-v1')
cart_pole_env.seed(0)
cart_pole_env.reset()
self.env = cart_pole_env
def reset(self):
obs = ag.utils.wrap_py_func(self.env.reset, tf.float64, ())
obs = tf.reshape(obs, (1, 4))
obs = tf.cast(obs, tf.float32)
return obs
def step(self, actions):
def take_one_step(actions):
obs, reward, done, _ = self.env.step(actions)
obs = obs.astype(np.float32)
reward = np.float32(reward)
return obs, reward, done
return ag.utils.wrap_py_func(take_one_step,
(tf.float32, tf.float32, tf.bool), (actions,))
class GraphRandomCartpoleEnv(object):
"""An environment that returns random actions and never finishes.
Used during benchmarking, it will cause training to run a constant number of
steps.
"""
def reset(self):
return tf.random.normal((1, 4))
def step(self, actions):
with tf.control_dependencies([actions]):
random_obs = tf.random.normal((1, 4))
fixed_reward = tf.constant(0.001)
done = tf.constant(False)
return random_obs, fixed_reward, done
#
# Eager implementation
#
def eager_append_discounted_rewards(discounted_rewards, rewards, discount_rate):
cdr = 0.0
reverse_discounted = []
for i in range(len(rewards) - 1, -1, -1):
cdr = cdr * discount_rate + rewards[i]
reverse_discounted.append(cdr)
discounted_rewards.extend(reversed(reverse_discounted))
return discounted_rewards
class EagerPolicyNetwork(tf.keras.Model):
"""Policy network for the cart-pole reinforcement learning problem.
The forward path of the network takes an observation from the cart-pole
environment (length-4 vector) and outputs an action.
"""
def __init__(self, hidden_size):
super(EagerPolicyNetwork, self).__init__()
self._hidden_layer = tf.keras.layers.Dense(
hidden_size, activation=tf.nn.elu)
self._output_layer = tf.keras.layers.Dense(1)
def call(self, inputs):
"""Calculates logits and action.
Args:
inputs: Observations from a step in the cart-pole environment, of shape
`(batch_size, input_size)`
Returns:
logits: the logits output by the output layer. This can be viewed as the
        likelihood values of choosing the left (0) action. Shape:
`(batch_size, 1)`.
actions: randomly selected actions ({0, 1}) based on the logits. Shape:
`(batch_size, 1)`.
"""
hidden = self._hidden_layer(inputs)
logits = self._output_layer(hidden)
left_prob = tf.nn.sigmoid(logits)
action_probs = tf.concat([left_prob, 1.0 - left_prob], 1)
self._grad_fn = eager.implicit_gradients(
self._get_cross_entropy_and_save_actions)
actions = tf.multinomial(tf.log(action_probs), 1)
return logits, actions
def _get_cross_entropy_and_save_actions(self, inputs):
logits, actions = self(inputs) # pylint:disable=not-callable
self._current_actions = actions
labels = 1.0 - tf.cast(actions, tf.float32)
return tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
def train(self, cart_pole_env, optimizer, discount_rate, num_games,
max_steps_per_game):
grad_list = None
step_counts = []
discounted_rewards = []
for _ in range(num_games):
obs = cart_pole_env.reset()
game_rewards = []
for step in range(max_steps_per_game):
grads_and_vars = self._grad_fn(tf.constant([obs], dtype=tf.float32))
grads, var_list = zip(*grads_and_vars)
actions = self._current_actions.numpy()[0][0]
if grad_list is None:
grad_list = [[g] for g in grads]
else:
for i in range(len(grads)):
grad_list[i].append(grads[i])
obs, reward, done = cart_pole_env.step(actions)
game_rewards.append(reward)
if reward < 0.1 or done:
step_counts.append(step + 1)
break
discounted_rewards = eager_append_discounted_rewards(
discounted_rewards, game_rewards, discount_rate)
discounted_rewards = tf.stack(discounted_rewards)
mean, variance = tf.nn.moments(discounted_rewards, [0])
normalized_rewards = (discounted_rewards - mean) / tf.sqrt(variance)
for i in range(len(grad_list)):
g = tf.stack(grad_list[i])
r = normalized_rewards
while r.shape.ndims < g.shape.ndims:
r = tf.expand_dims(r, -1)
grad_list[i] = tf.reduce_mean(g * r, axis=0)
optimizer.apply_gradients(
zip(grad_list, var_list), global_step=tf.train.get_global_step())
return tf.stack(step_counts)
def eager_train_model(policy_network, cart_pole_env, optimizer, iterations):
"""Trains the policy network for a given number of iterations."""
mean_steps_per_iteration = []
for _ in range(iterations):
steps_per_game = policy_network.train(
cart_pole_env,
optimizer,
discount_rate=0.95,
num_games=20,
max_steps_per_game=200)
mean_steps_per_iteration.append(tf.reduce_mean(steps_per_game))
return mean_steps_per_iteration
class EagerGymCartpoleEnv(object):
"""An env backed by OpenAI Gym's CartPole environment.
  Used only to confirm that the model is functional.
"""
def __init__(self):
cart_pole_env = gym.make('CartPole-v1')
cart_pole_env.seed(0)
cart_pole_env.reset()
self.env = cart_pole_env
def reset(self):
return self.env.reset()
def step(self, actions):
obs, reward, done, _ = self.env.step(actions)
return obs, reward, done
class EagerRandomCartpoleEnv(object):
"""An environment that returns random actions and never finishes.
Used during benchmarking, it will cause training to run a constant number of
steps.
"""
def reset(self):
return np.random.normal(size=(4,))
def step(self, actions):
with tf.control_dependencies([actions]):
random_obs = np.random.normal(size=(4,))
fixed_reward = 0.001
done = False
return random_obs, fixed_reward, done
def graph_demo_training():
"""Not used in the benchmark. Used to confirm a functional model."""
with tf.Graph().as_default():
tf.set_random_seed(0)
network = GraphPolicyNetwork(hidden_size=5)
network.build((1, 4))
env = GraphGymCartpoleEnv()
opt = tf.train.AdamOptimizer(0.05)
train_ops = graph_train_model(network, env, opt, iterations=5)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
steps_per_iteration = sess.run(train_ops)
for i, steps in enumerate(steps_per_iteration):
print('Step {} iterations: {}'.format(i, steps))
def eager_demo_training():
with context.eager_mode():
network = EagerPolicyNetwork(hidden_size=5)
network.build((1, 4))
env = EagerGymCartpoleEnv()
opt = tf.train.AdamOptimizer(0.05)
steps_per_iteration = eager_train_model(network, env, opt, iterations=5)
for i, steps in enumerate(steps_per_iteration):
print('Step {} iterations: {}'.format(i, steps))
class RLCartPoleBenchmark(benchmark_base.ReportingBenchmark):
"""Actual benchmark.
Trains the RL agent a fixed number of times, on random environments that
  result in a constant number of steps.
"""
def benchmark_cartpole(self):
def train_session(sess, ops):
return lambda: sess.run(ops)
def train_eager(network, env, opt):
return lambda: eager_train_model(network, env, opt, iterations=10)
for model_size in (10, 100, 1000):
with tf.Graph().as_default():
network = GraphPolicyNetwork(hidden_size=model_size)
network.build((1, 4))
env = GraphRandomCartpoleEnv()
opt = tf.train.AdamOptimizer(0.05)
train_ops = graph_train_model(network, env, opt, iterations=10)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
self.time_execution(('cartpole', 'autograph', model_size),
train_session(sess, train_ops), 20)
with context.eager_mode():
network = EagerPolicyNetwork(hidden_size=model_size)
network.build((1, 4))
env = EagerRandomCartpoleEnv()
opt = tf.train.AdamOptimizer(0.05)
self.time_execution(('cartpole', 'eager', model_size),
train_eager(network, env, opt), 20)
if __name__ == '__main__':
tf.test.main()

File diff suppressed because one or more lines are too long

View File

@ -1,652 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "etTmZVFN8fYO"
},
"source": [
"This notebook runs a basic speed test for a short training loop of a neural network training on the MNIST dataset."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "eqOvRhOz8SWs"
},
"source": [
"### Imports"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "nHY0tntRizGb"
},
"outputs": [],
"source": [
"!pip install -U -q tf-nightly"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "Pa2qpEmoVOGe"
},
"outputs": [],
"source": [
"import gzip\n",
"import os\n",
"import shutil\n",
"import time\n",
"\n",
"import numpy as np\n",
"import six\n",
"from six.moves import urllib\n",
"import tensorflow as tf\n",
"\n",
"from tensorflow.contrib import autograph as ag\n",
"from tensorflow.contrib.eager.python import tfe\n",
"from tensorflow.python.eager import context\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "PZWxEJFM9A7b"
},
"source": [
"### Testing boilerplate"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "kfZk9EFZ5TeQ"
},
"outputs": [],
"source": [
"# Test-only parameters. Test checks successful completion not correctness. \n",
"burn_ins = 1\n",
"trials = 1\n",
"max_steps = 2\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "k0GKbZBJ9Gt9"
},
"source": [
"### Speed test configuration"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "gWXV8WHn43iZ"
},
"outputs": [],
"source": [
"#@test {\"skip\": true} \n",
"burn_ins = 3\n",
"trials = 10\n",
"max_steps = 500\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "kZV_3pGy8033"
},
"source": [
"### Data source setup"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "YfnHJbBOBKae"
},
"outputs": [],
"source": [
"def download(directory, filename):\n",
" filepath = os.path.join(directory, filename)\n",
" if tf.gfile.Exists(filepath):\n",
" return filepath\n",
" if not tf.gfile.Exists(directory):\n",
" tf.gfile.MakeDirs(directory)\n",
" url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n",
" zipped_filepath = filepath + '.gz'\n",
" print('Downloading %s to %s' % (url, zipped_filepath))\n",
" urllib.request.urlretrieve(url, zipped_filepath)\n",
" with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:\n",
" shutil.copyfileobj(f_in, f_out)\n",
" os.remove(zipped_filepath)\n",
" return filepath\n",
"\n",
"\n",
"def dataset(directory, images_file, labels_file):\n",
" images_file = download(directory, images_file)\n",
" labels_file = download(directory, labels_file)\n",
"\n",
" def decode_image(image):\n",
" # Normalize from [0, 255] to [0.0, 1.0]\n",
" image = tf.decode_raw(image, tf.uint8)\n",
" image = tf.cast(image, tf.float32)\n",
" image = tf.reshape(image, [784])\n",
" return image / 255.0\n",
"\n",
" def decode_label(label):\n",
" label = tf.decode_raw(label, tf.uint8)\n",
" label = tf.reshape(label, [])\n",
" return tf.to_int32(label)\n",
"\n",
" images = tf.data.FixedLengthRecordDataset(\n",
" images_file, 28 * 28, header_bytes=16).map(decode_image)\n",
" labels = tf.data.FixedLengthRecordDataset(\n",
" labels_file, 1, header_bytes=8).map(decode_label)\n",
" return tf.data.Dataset.zip((images, labels))\n",
"\n",
"\n",
"def mnist_train(directory):\n",
" return dataset(directory, 'train-images-idx3-ubyte',\n",
" 'train-labels-idx1-ubyte')\n",
"\n",
"def mnist_test(directory):\n",
" return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')\n",
"\n",
"def setup_mnist_data(is_training, hp, batch_size):\n",
" if is_training:\n",
" ds = mnist_train('/tmp/autograph_mnist_data')\n",
" ds = ds.cache()\n",
" ds = ds.shuffle(batch_size * 10)\n",
" else:\n",
" ds = mnist_test('/tmp/autograph_mnist_data')\n",
" ds = ds.cache()\n",
" ds = ds.repeat()\n",
" ds = ds.batch(batch_size)\n",
" return ds\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "qzkZyZcS9THu"
},
"source": [
"### Keras model definition"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "x_MU13boiok2"
},
"outputs": [],
"source": [
"def mlp_model(input_shape):\n",
" model = tf.keras.Sequential((\n",
" tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n",
" tf.keras.layers.Dense(100, activation='relu'),\n",
" tf.keras.layers.Dense(10, activation='softmax')))\n",
" model.build()\n",
" return model\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "DXt4GoTxtvn2"
},
"source": [
"# AutoGraph"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "W51sfbONiz_5"
},
"outputs": [],
"source": [
"def predict(m, x, y):\n",
" y_p = m(x)\n",
" losses = tf.keras.losses.categorical_crossentropy(y, y_p)\n",
" l = tf.reduce_mean(losses)\n",
" accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n",
" accuracy = tf.reduce_mean(accuracies)\n",
" return l, accuracy\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "CsAD0ajbi9iZ"
},
"outputs": [],
"source": [
"def fit(m, x, y, opt):\n",
" l, accuracy = predict(m, x, y)\n",
" opt.minimize(l)\n",
" return l, accuracy\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "RVw57HdTjPzi"
},
"outputs": [],
"source": [
"def get_next_batch(ds):\n",
" itr = ds.make_one_shot_iterator()\n",
" image, label = itr.get_next()\n",
" x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n",
" y = tf.one_hot(tf.squeeze(label), 10)\n",
" return x, y\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "UUI0566FjZPx"
},
"outputs": [],
"source": [
"def train(train_ds, test_ds, hp):\n",
" m = mlp_model((28 * 28,))\n",
" opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n",
"\n",
" train_losses = []\n",
" test_losses = []\n",
" train_accuracies = []\n",
" test_accuracies = []\n",
" ag.set_element_type(train_losses, tf.float32)\n",
" ag.set_element_type(test_losses, tf.float32)\n",
" ag.set_element_type(train_accuracies, tf.float32)\n",
" ag.set_element_type(test_accuracies, tf.float32)\n",
"\n",
" i = tf.constant(0)\n",
" while i \u003c hp.max_steps:\n",
" train_x, train_y = get_next_batch(train_ds)\n",
" test_x, test_y = get_next_batch(test_ds)\n",
" step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n",
" step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n",
"\n",
" train_losses.append(step_train_loss)\n",
" test_losses.append(step_test_loss)\n",
" train_accuracies.append(step_train_accuracy)\n",
" test_accuracies.append(step_test_accuracy)\n",
"\n",
" i += 1\n",
" return (ag.stack(train_losses), ag.stack(test_losses),\n",
" ag.stack(train_accuracies), ag.stack(test_accuracies))\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"height": 215
},
"colab_type": "code",
"executionInfo": {
"elapsed": 12156,
"status": "ok",
"timestamp": 1531752050611,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": 240
},
"id": "K1m8TwOKjdNd",
"outputId": "bd5746f2-bf91-44aa-9eff-38eb11ced33f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('Duration:', 0.6226680278778076)\n",
"('Duration:', 0.6082069873809814)\n",
"('Duration:', 0.6223258972167969)\n",
"('Duration:', 0.6176440715789795)\n",
"('Duration:', 0.6309840679168701)\n",
"('Duration:', 0.6180410385131836)\n",
"('Duration:', 0.6219630241394043)\n",
"('Duration:', 0.6183009147644043)\n",
"('Duration:', 0.6176400184631348)\n",
"('Duration:', 0.6476900577545166)\n",
"('Mean duration:', 0.62254641056060789, '+/-', 0.0099792188690656976)\n"
]
}
],
"source": [
"#@test {\"timeout\": 90}\n",
"with tf.Graph().as_default():\n",
" hp = tf.contrib.training.HParams(\n",
" learning_rate=0.05,\n",
" max_steps=max_steps,\n",
" )\n",
" train_ds = setup_mnist_data(True, hp, 500)\n",
" test_ds = setup_mnist_data(False, hp, 100)\n",
" tf_train = ag.to_graph(train)\n",
" losses = tf_train(train_ds, test_ds, hp)\n",
"\n",
" with tf.Session() as sess:\n",
" durations = []\n",
" for t in range(burn_ins + trials):\n",
" sess.run(tf.global_variables_initializer())\n",
"\n",
" start = time.time()\n",
" (train_losses, test_losses, train_accuracies,\n",
" test_accuracies) = sess.run(losses)\n",
"\n",
" if t \u003c burn_ins:\n",
" continue\n",
"\n",
" duration = time.time() - start\n",
" durations.append(duration)\n",
" print('Duration:', duration)\n",
"\n",
" print('Mean duration:', np.mean(durations), '+/-', np.std(durations))\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "A06kdgtZtlce"
},
"source": [
"# Eager"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "hBKOKGrWty4e"
},
"outputs": [],
"source": [
"def predict(m, x, y):\n",
" y_p = m(x)\n",
" losses = tf.keras.losses.categorical_crossentropy(tf.cast(y, tf.float32), y_p)\n",
" l = tf.reduce_mean(losses)\n",
" accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n",
" accuracy = tf.reduce_mean(accuracies)\n",
" return l, accuracy\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "HCgTZ0MTt6vt"
},
"outputs": [],
"source": [
"def train(ds, hp):\n",
" m = mlp_model((28 * 28,))\n",
" opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n",
"\n",
" train_losses = []\n",
" test_losses = []\n",
" train_accuracies = []\n",
" test_accuracies = []\n",
"\n",
" i = 0\n",
" train_test_itr = tfe.Iterator(ds)\n",
" for (train_x, train_y), (test_x, test_y) in train_test_itr:\n",
" train_x = tf.to_float(tf.reshape(train_x, (-1, 28 * 28)))\n",
" train_y = tf.one_hot(tf.squeeze(train_y), 10)\n",
" test_x = tf.to_float(tf.reshape(test_x, (-1, 28 * 28)))\n",
" test_y = tf.one_hot(tf.squeeze(test_y), 10)\n",
"\n",
" if i \u003e hp.max_steps:\n",
" break\n",
"\n",
" with tf.GradientTape() as tape:\n",
" step_train_loss, step_train_accuracy = predict(m, train_x, train_y)\n",
" grad = tape.gradient(step_train_loss, m.variables)\n",
" opt.apply_gradients(zip(grad, m.variables))\n",
" step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n",
"\n",
" train_losses.append(step_train_loss)\n",
" test_losses.append(step_test_loss)\n",
" train_accuracies.append(step_train_accuracy)\n",
" test_accuracies.append(step_test_accuracy)\n",
"\n",
" i += 1\n",
" return train_losses, test_losses, train_accuracies, test_accuracies\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"height": 215
},
"colab_type": "code",
"executionInfo": {
"elapsed": 52499,
"status": "ok",
"timestamp": 1531752103279,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": 240
},
"id": "plv_yrn_t8Dy",
"outputId": "55d5ab3d-252d-48ba-8fb4-20ec3c3e6d00"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('Duration:', 3.9973549842834473)\n",
"('Duration:', 4.018772125244141)\n",
"('Duration:', 3.9740989208221436)\n",
"('Duration:', 3.9922947883605957)\n",
"('Duration:', 3.9795801639556885)\n",
"('Duration:', 3.966722011566162)\n",
"('Duration:', 3.986541986465454)\n",
"('Duration:', 3.992305040359497)\n",
"('Duration:', 4.012261867523193)\n",
"('Duration:', 4.004716157913208)\n",
"('Mean duration:', 3.9924648046493529, '+/-', 0.015681688635624851)\n"
]
}
],
"source": [
"#@test {\"timeout\": 90}\n",
"with context.eager_mode():\n",
" durations = []\n",
" for t in range(burn_ins + trials):\n",
" hp = tf.contrib.training.HParams(\n",
" learning_rate=0.05,\n",
" max_steps=max_steps,\n",
" )\n",
" train_ds = setup_mnist_data(True, hp, 500)\n",
" test_ds = setup_mnist_data(False, hp, 100)\n",
" ds = tf.data.Dataset.zip((train_ds, test_ds))\n",
" start = time.time()\n",
" (train_losses, test_losses, train_accuracies,\n",
" test_accuracies) = train(ds, hp)\n",
" \n",
" train_losses[-1].numpy()\n",
" test_losses[-1].numpy()\n",
" train_accuracies[-1].numpy()\n",
" test_accuracies[-1].numpy()\n",
"\n",
" if t \u003c burn_ins:\n",
" continue\n",
"\n",
" duration = time.time() - start\n",
" durations.append(duration)\n",
" print('Duration:', duration)\n",
"\n",
" print('Mean duration:', np.mean(durations), '+/-', np.std(durations))\n"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [
"eqOvRhOz8SWs",
"PZWxEJFM9A7b",
"kZV_3pGy8033"
],
"default_view": {},
"name": "Autograph vs. Eager MNIST speed test",
"provenance": [
{
"file_id": "1tAQW5tHUgAc8M4-iwwJm6Xs6dV9nEqtD",
"timestamp": 1530297010607
},
{
"file_id": "18dCjshrmHiPTIe1CNsL8tnpdGkuXgpM9",
"timestamp": 1530289467317
},
{
"file_id": "1DcfimonWU11tmyivKBGVrbpAl3BIOaRG",
"timestamp": 1522272821237
},
{
"file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K",
"timestamp": 1522238054357
},
{
"file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ",
"timestamp": 1521743157199
},
{
"file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-",
"timestamp": 1520522344607
}
],
"version": "0.3.2",
"views": {}
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@ -1,935 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "b9R-4ezU3NH0"
},
"source": [
"## AutoGraph: examples of simple algorithms\n",
"\n",
"This notebook shows how you can use AutoGraph to compile simple algorithms and run them in TensorFlow.\n",
"\n",
"It requires the nightly build of TensorFlow, which is installed below."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "TuWj26KWz1fZ"
},
"outputs": [],
"source": [
"!pip install -U -q tf-nightly-2.0-preview"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Cp7iTarmz62Y"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"tf = tf.compat.v2\n",
"tf.enable_v2_behavior()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "3kudk1elq0Gh"
},
"source": [
"### Fibonacci numbers\n",
"\n",
"https://en.wikipedia.org/wiki/Fibonacci_number\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"height": 187
},
"colab_type": "code",
"executionInfo": {
"elapsed": 709,
"status": "ok",
"timestamp": 1563825398552,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": 240
},
"id": "H7olFlMXqrHe",
"outputId": "25243e7b-99a7-4a6d-ad00-e97c52be7d97"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 : 1\n",
"1 : 2\n",
"2 : 3\n",
"3 : 5\n",
"4 : 8\n",
"5 : 13\n",
"6 : 21\n",
"7 : 34\n",
"8 : 55\n",
"9 : 89\n"
]
}
],
"source": [
"@tf.function\n",
"def fib(n):\n",
" f1 = 0\n",
" f2 = 1\n",
" for i in tf.range(n):\n",
" tmp = f2\n",
" f2 = f2 + f1\n",
" f1 = tmp\n",
" tf.print(i, ': ', f2)\n",
" return f2\n",
"\n",
"\n",
"_ = fib(tf.constant(10))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "p8zZyj-tq4K3"
},
"source": [
"#### Generated code"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "UeWjK8rHq6Cj"
},
"outputs": [],
"source": [
"print(tf.autograph.to_code(fib.python_function))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "eIfVy6ZTrFEH"
},
"source": [
"### Fizz Buzz\n",
"\n",
"https://en.wikipedia.org/wiki/Fizz_buzz"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"height": 119
},
"colab_type": "code",
"executionInfo": {
"elapsed": 663,
"status": "ok",
"timestamp": 1563825401385,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": 240
},
"id": "33CAheYsrEQ7",
"outputId": "2a88b65d-4fed-4d96-8770-0c68ffece861"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Buzz\n",
"11\n",
"Fizz\n",
"13\n",
"14\n",
"FizzBuzz\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"\n",
"\n",
"@tf.function(experimental_autograph_options=tf.autograph.experimental.Feature.EQUALITY_OPERATORS)\n",
"def fizzbuzz(i, n):\n",
" while i \u003c n:\n",
" msg = ''\n",
" if i % 3 == 0:\n",
" msg += 'Fizz'\n",
" if i % 5 == 0:\n",
" msg += 'Buzz'\n",
" if msg == '':\n",
" msg = tf.as_string(i)\n",
" tf.print(msg)\n",
" i += 1\n",
" return i\n",
"\n",
"_ = fizzbuzz(tf.constant(10), tf.constant(16))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Lkq3DBGOv3fA"
},
"source": [
"#### Generated code"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "bBhFIIaZrxvx"
},
"outputs": [],
"source": [
"print(tf.autograph.to_code(fizzbuzz.python_function))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "BNRtprSvwJgk"
},
"source": [
"### Conway's Game of Life\n",
"\n",
"https://en.wikipedia.org/wiki/Conway%27s_Game_of_Life"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "r8_0ioEuAI-a"
},
"source": [
"#### Testing boilerplate"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "7moIlf8VABkl"
},
"outputs": [],
"source": [
"NUM_STEPS = 1"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "QlEvfIQPAYF5"
},
"source": [
"#### Game of Life for AutoGraph\n",
"\n",
"Note: the code may take a while to run."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "5pCK2qQSAAK4"
},
"outputs": [],
"source": [
"#@test {\"skip\": true} \n",
"NUM_STEPS = 75"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "GPZANPdhMagD"
},
"source": [
"Note: This code uses a non-vectorized algorithm, which is quite slow. For 75 steps, it will take a few minutes to run. "
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"height": 309
},
"colab_type": "code",
"executionInfo": {
"elapsed": 147654,
"status": "ok",
"timestamp": 1563825336196,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": 240
},
"id": "hC3qMqryPDHS",
"outputId": "56a095a3-28a3-455d-e95e-2c4c9dcd97d2"
},
"outputs": [
{
"data": {
"text/html": [
"\u003cvideo width=\"432\" height=\"288\" controls autoplay loop\u003e\n",
" \u003csource type=\"video/mp4\" src=\"data:video/mp4;base64,AAAAHGZ0eXBNNFYgAAACAGlzb21pc28yYXZjMQAAAAhmcmVlAABdAG1kYXQAAAKuBgX//6rcRem9\n",
"5tlIt5Ys2CDZI+7veDI2NCAtIGNvcmUgMTUyIHIyODU0IGU5YTU5MDMgLSBILjI2NC9NUEVHLTQg\n",
"QVZDIGNvZGVjIC0gQ29weWxlZnQgMjAwMy0yMDE3IC0gaHR0cDovL3d3dy52aWRlb2xhbi5vcmcv\n",
"eDI2NC5odG1sIC0gb3B0aW9uczogY2FiYWM9MSByZWY9MyBkZWJsb2NrPTE6MDowIGFuYWx5c2U9\n",
"MHgzOjB4MTEzIG1lPWhleCBzdWJtZT03IHBzeT0xIHBzeV9yZD0xLjAwOjAuMDAgbWl4ZWRfcmVm\n",
"PTEgbWVfcmFuZ2U9MTYgY2hyb21hX21lPTEgdHJlbGxpcz0xIDh4OGRjdD0xIGNxbT0wIGRlYWR6\n",
"b25lPTIxLDExIGZhc3RfcHNraXA9MSBjaHJvbWFfcXBfb2Zmc2V0PS0yIHRocmVhZHM9OSBsb29r\n",
"YWhlYWRfdGhyZWFkcz0xIHNsaWNlZF90aHJlYWRzPTAgbnI9MCBkZWNpbWF0ZT0xIGludGVybGFj\n",
"ZWQ9MCBibHVyYXlfY29tcGF0PTAgY29uc3RyYWluZWRfaW50cmE9MCBiZnJhbWVzPTMgYl9weXJh\n",
"bWlkPTIgYl9hZGFwdD0xIGJfYmlhcz0wIGRpcmVjdD0xIHdlaWdodGI9MSBvcGVuX2dvcD0wIHdl\n",
"aWdodHA9MiBrZXlpbnQ9MjUwIGtleWludF9taW49MTAgc2NlbmVjdXQ9NDAgaW50cmFfcmVmcmVz\n",
"aD0wIHJjX2xvb2thaGVhZD00MCByYz1jcmYgbWJ0cmVlPTEgY3JmPTIzLjAgcWNvbXA9MC42MCBx\n",
"cG1pbj0wIHFwbWF4PTY5IHFwc3RlcD00IGlwX3JhdGlvPTEuNDAgYXE9MToxLjAwAIAAAAQZZYiE\n",
"ABH//veIHzLLafk613IR560urR9Q7kZxXqS9/iAAAAMAFpyZQ/thx05aw0AAQoAAjZrf0Z7SQAFS\n",
"RBmrGveunhOj4JFso/zYXaRjQ18w/5BhxFIRpIkBeRXl9T8OOtGMbM52JtIMXIY7KRr49/IsKi0w\n",
"jJUK8Z7XIFmlAjIU+jSbWER5LmeK+6/diSLijDB3co/ebDgChTdnt/smJJAlFMJhzTUcdwoA8NQo\n",
"YBnpXwCtHd9MDNyz4x4zrqfgfXAXtVDOuKqK+ZIROmkudESU5HAc84NxG9mIFkHTHpfRFX0vfuvN\n",
"v30XneTe8IilYhOJYkyOcVBz9L5D3N5P2RHbPf8d2Ia4qkwGurGLJl8PxjFsKE4dm+f6WYtxh4/M\n",
"EbibuuIVHuFVTrhDBdjGsnlvGJ613cHSu4frv4bqhIfOz9nOKI/zhLw9zlvfAkAek0G+jTz8be7+\n",
"o/ndntGdno6L1LXJpdgGJYFOyZwDpk3suJqu9FKdCFsjDfQ4s5OYpZkBRm/h6ksvqs/jKOI7H7Eu\n",
"JEDtMn0Px1875SS+KLSHaHwtTCNzTTTEE83rjSnRcLH2qekoCAzC/F7u+tWoo8/5q7AU8ZwbFyde\n",
"C0AcLGLOTLX2dctD5sMzDYlYtX/lYiEND4SUALBVfbetB5IH67pM/22hp7cM4zkyUfekvXZeKUpq\n",
"ihxpjZ/b0GfRGel+eaIkRAMer8l0HHBl4xOpdwEUiGEQqacmsmAKA7/Wn0I4FZAkAeHbrP6JQw8G\n",
"T6oLn8jHc2YBwe6YY+t5SuugRFwnijdFTQ2IYMHZ9spzZjJhn/lftFm13UY9ay8CDty2j8dXZfss\n",
"pdN3RSB6EMFrirN6yUkoxa8UPGBKHs9MUFO5MnKDgADHT4JhBGInxUASlDV0lsFB0GH9ED4tkRc6\n",
"7SnaMmZwf9T2i4a1NSsheM+jHEQWr9fgPDBABuIyToLYLrnVeLXqSC8JMeZigh4GOpQKyiIsG8oa\n",
"f6kiBTwG/5RebTqU6O7rrQLj5Wd5YFdqaacUZGByo8AxJ60NHIoQcxeNjsWAj6m8SKd2+g3en70+\n",
"zVQW9HkvHI7nnRF3FhwhZYu/LvproEPyWSYykJIx75ojR14WE7oWSjYs0X2AFiwEouayVGii6owJ\n",
"gdlCmnN8HoqT5PPnaOWG7mPgq/3meUuz982ZX4+4VMage3Fe0K3cqRdKLTge+gs4pyQbSUIdrgo3\n",
"4P4R1ejF0wAW1R8YjLZz6fQUzzzchgNN0t7aa8tlO2yDCmII5BbaYJXJrRvBm8Lb1m7TLILNalgu\n",
"RMjYD4Pf/P4iQqWsBEdgB3p334RMzrBfcviq+49N2SRQlYxV0SbSMdybZaH+vxuw+VyvLt3ulEcF\n",
"rmBwnxL4kpGATPv8mogAAAMAUMEAAAI7QZokbEEf/rUqgAYz+kaAoYS6oZnCZBWChU49QzRvBVh/\n",
"3Pl1tY/3h6ui3wW2qKCfpdwQ1h/uuKhRazpong7+Xsbw5g3mv3E7I0N68sUiey8Dbt0hMUrR6zYj\n",
"YtzMQ7gEdgcbbOEgu3H73w44JvEzvgZ4iO4Q2Kwp7BHY2uxxtdUENoG1kHXqnnQawFSCHZ9W6pRZ\n",
"ZX580jW/ekv7tzX5SLrr2mknIiIEL/9OqO/hdKRWyIS92L0VbeMgboQPIpdXZEemH8ScfWR641oo\n",
"Kb2ZqixayrynX4qeQdDAXvtKdnTPfgTsOJHs6zrnaaKb6SpoCg9ffzFUfiQ1YwLPZpLhwkJ1F58m\n",
"QtliSU1LCArOxcL0CdX1xv0PO1XbIga8mvD2ON78HrYIlpd7r9MIJUgGiGlRxLTUITjvxtxjLYBG\n",
"TBzSQ2Mqy08Y4xvBh9/AZrWGoBvplKVOooBAXsS/J3OngcAaMApnGniTlEgacIB/4ihqQm9Zync1\n",
"WrLEldONGr9K6gbteZcFnK/hoe6B53agN6YwjF+Hm1IYltzK42eiNQbmeo0nT6xx724Sek57Pcpp\n",
"/+64lZEYNhMLw61j8cLCmWJLqJ9+OlV3Tu4kvqWM5A7mBmXunK5EElFvFoiaHvfKnFzVKUZHVN47\n",
"dwwOu2bQK/GEFcs57H1A4Ddl2JAlJt4ZWgrJx+vzAgyhhcl1LtQgQcd3rX3aPisDf1CYETnay05i\n",
"xe8yUL0AVMzI07+lqERP6auGU//nlrslfAAAAS1BnkJ4h38AGAsZbANezx+IWo4Ni9MoMfKTC08P\n",
"cqaDTyueeuPLGgFgW9U33kZ+Bw1xhP+VnfaIAfTxYvkb1WNMMRMsh5PjwSMCmaFIlQvFeKZwdgkf\n",
"0eHuoCcg/XQXRqCvEyyYU7Kr945fY16Tu/18Zd8NU8RAJRLFspmBVoIPZ/aTPmSIXDq8KOSzL6TG\n",
"sWN+V8RKxGwExIfHZpdEvHu1tOeg+pVzKTracfnYiBxxlkuVIyzOz2mFv1LQ72jZQGocAdWS14tD\n",
"EtCsmNljiTGQDRggnoajq8kpnFHws9ZMWmcsV4dQvczexFmx4YibNvvMPauj3CH/KK6FXvQFumid\n",
"ftiga3Uno6si2epmOuEVTuVQwXsgCmOyejpjAiAjZuUS1zq40WginD1EPNgRAAAAXQGeYXRDfwAh\n",
"r6zZu6OyBrfB5mVsAz3QNRRqvrwAcnFznD7NXanOaWlAADNOwlJX/xGmO79sH9XeNRT/FnLuEPBH\n",
"1GJhJV/Xt2R0YziQPpgXV9BLMr5IaMaU9R2CpgAAAPgBnmNqQ38AHhCAmS1kGlkSnBkADoOXdXaF\n",
"NGZr+Q4fCvQ7bHDsrrZk+gghfDnB3EgAw+hgyCz7QjPCBdm4Oua2VioU2d4nUZ+UABLNnRNNghIa\n",
"znH4EU6++iAxhcURNicOGGgil2sQO5YirsL6J7S/TznXYcILcn91E9qrSkdqAKeiqMttbt/NlBlt\n",
"zFtTLIQV87eeTgQtRSaGjNkYcjtT9zsSroMxdQkaS/rgzWfPKqioru5///iiFvV7FHhGNapsB8Ep\n",
"xA6YqLEIyfxd3iBKiJ3g/96H/WMQrMVl8ykLYh6g9L/mEknpMxDRuX+/d5vuR5TJpN2l4QAAAY9B\n",
"mmdJqEFomUwII//+tSqABipnkgGrJGhoF2xhqIGFJgrTiV28TOHP6iMSZwA4LzauSvgcy42/qpKz\n",
"PF+GKWIn2EJeWsQWOqhnFWAeu8Qy08RHEYzw2BIfhXKPnsvQ1D45gRUsCZjYq85tliORVeVqHlvt\n",
"fzWrMqI5f+favhs74Q/1bo2ebSMVUSFuP3HPqFVDjXrf/wjJSgWTFPNzCZtjDghfnhYgAzPVh4sd\n",
"mfpnfQi7UGcAu+X0SPRW+sCzjBKyZsabYXRLvCvcRgXcWHRJnqJZ7DbIL5Ahmra4MUmiAdrDqxi1\n",
"yixz8Ge2MnwDKePhHbASj9FgVyabApZmODkYAk9x2eNsu3NC/GWuEsOYUEJXb3NkJ3H0Ehpogb5q\n",
"/7IADF2Rk2r94PZTFE6TdqRa+DeKrhf1PoBJxN2bNx2sA7Pci476Sn+ZpPsAPTlXaikJNRAhO4tD\n",
"lakPd29Edmfvk34bCqY6rFMuCfUJ3yzCy+VRKB59CtgS68dVzaJO/FxZ2Of18yjXsScM2fL16/kA\n",
"AADDQZ6FRREsN/8AHa60qBaQmR4IRAA6Dl3Sc6VtGJbtr5vbN23f25BY5Mbt9ZodJaqeGLgSZDt5\n",
"tMt3+exLq/o1or+DyDOaUjfDuI6HO9EMKVIFrK5bBNySwYGQ9ZOLXviohcSZAskgQCT8YbljWqgY\n",
"W5O+m+Ip3OoA9JMxAp4EiGRPR1hmuQDeRomyGX7bvvzp+lmhQcgx50Gtf2FsWph71RE5OIfz3vbU\n",
"YPJzvstNoHMLjQVN28uexbTk/wUswGjCQ8u5AAABFwGepmpDfwAhvaAbJNR/9ddNI1ZNZPr5vm6q\n",
"XTetXH7Eo8GqFltKJbOb+WxFxg1OZ9LY7Pm4G1n+FvJzAc9iMK3kbM6geeeFIdRl75A0UZYsXIff\n",
"dQXiQxB/kP/GUeJS/ghHdsFXhovY2ei0jBYXhl7XCQdiM+OxqVpdBNYdLY+vhvtTydDweWAQhmfY\n",
"3fYN3w2o0+YtvleCAQNIu+tN7OfSeOifT7EOLQk4YDYkvT1QcI6scYDf1en6ihiP1DSq11Clzx8a\n",
"ja6cddGuoMqDaNkxCF1dzf2Jvz1VA4BpWPjukcCUvSBL5Hjn5IenmZHNevhC9Ri5TKMMAK1OUZos\n",
"eUJttkHLI36Z4EqqgVQeXc7fMR78LG9GpQAAATJBmqlJqEFsmUwUTBH//rUqgAcd7WUAG1wL+eMP\n",
"5NbNjI1PanDtCkQqkSzemsYEjSdqyjDQBhMRhcVkBjrLnQ37QRY6anUo9HtaOXKEvV3Oq3t3zJnU\n",
"VnRnO4+DsYDha+hVjf2RQfz8iIHBAMZBzDCidKRjdK++FyTTJT//wjjoyDzrLD81EvvOEfP1hNq1\n",
"E7Mf/LNi4VzZp3xaz5k3oYD4Uh8itElOoUglEcP1/ghF2UcJA9hOtkSUpVhA8+T8Ytc1zpVMfYyg\n",
"QqbyRa4EvI2+PCgNWtypZmPOW/fUb8LPNYTg5GLhzbOmSjYpenEUzkib0QksNLKbj/E9aHrV1qHX\n",
"qXiny+3UUPxYGvj/pDuYRozh1EchMNkv/eHEkrQhTQjnyxDirLtyAwkvICbz8w9UK2AAAAC1AZ7I\n",
"akN/ACK9oCBuM4cceanCEEWpV8cuy27lpLcHp0RFJ/onjSEljOG8VqS2Rkf30kIRre+KMlNGVcvp\n",
"cL4orO6Yp5KjC/RRBwQz/yE8UKLNeO0Y0FFhQfICXcBtO9ndieTXXlspFHuGf4S6CeBKlAO/lDFn\n",
"Bm6rf4RqP1vvLrD8KUBlig+AFH77l/U3BNsHxmcjURJ4rz9SBUp3dWhkBmKNCP57UtC9bKnqFyE+\n",
"YvACZ+sMCAAAAZlBms1J4QpSZTAgj//+tSqAClE1egBKEwbZY3t792fWy96pbeQQCnoXHta8keYB\n",
"6YD4iyrisk5RAGXAP8hftXkqsIp3gIADtqeyulunIxMvA+tHyMYI4mH7Ktx24JQCDLGwr+SW5Lfl\n",
"LFzLN5Z5EpfMBtjuN1e5MGJfkKE7RLofReD1fgshPg5Hiu3eNzKNtXPqCUQOQrANHyjLVDHW1On8\n",
"GbpMg//3+EW5h//MyUrV8C3bm65GCPAdr+IiAQS5PLqRpJaqPFXYImLzCfEF4IcxGqfKzcnaOGUe\n",
"P5zhUa+at6SYruNLfSBlr3+mvyhAAxPUBpQBX3a2ZIbz3QLaxiA/KmUnrCDmuWAQmEAoRWFYDkhB\n",
"vSu304LzlIj5BSPPqNvyTdiIsLpzAu+SwxleN8rOU8p84R24aRhgQwchoF64pWQkYvhDlixS1XkC\n",
"+1BFsz/ugThqWNrj6DMWcUAmd8tN3JWA8raGQmJpBH1Zjd5483GFE2+DssYAdvIzFktdYvwqJy33\n",
"xqAAiKb/jZmChnRmwaKmyp+usNPBAAAA+UGe60U0TDv/ABgTM0cFpiU9S5COo+Eq1a5EDpKRq+6p\n",
"lSs4dhBzMdhHGYju3Syu9sir+n5TA4S4EozXRjp4djOH9s6Ebl4mnuRqUkAVVyRRxloLXXdAVwvm\n",
"Kw2kt3nH3KtGiXPZtoKRlLMwsYrakek54VGjJMSSK7z2j4bZfzdU5fWILhtGELYhukSGMv6CXtq0\n",
"ugZLCx24z5CJjXHZ6aJugoOXVvLE5AMKcYDe/LowGji7OLeFgeB849mfSaUGlnh7jxuhBOU+fRS4\n",
"p0ITI4vXzUUR4XVTQrOXBNie8HQwoivm+WRv0nW15Zl5mZ7wAnqm6XldppA1IAAAAMIBnwp0Q38A\n",
"Ir2gIG4zgb64sxYLzhi9P+r7lwy6Wa7RRkAjTYM9mY6ueOaRzgw6T2RlVKQ/Wnw9OUPsoB+98v3K\n",
"7Ai/8Ku9oiX4fIaC4XxFxl+0lQDznNsd4UfPo3AQh6FoBHug176P/7mBbtXW9HioX3mZhTRXJOlh\n",
"Psk7HP1i1klJ4f63KMPuZvFOjkq75Z+u+/aiOQvmn6+lP0r2vSaqs7nxNSGwPqSwNXaUgQz58aD0\n",
"pB2v6eKf+Yy3eGu8f7HHrAAAANkBnwxqQ38AH77opN4Quy1TZxAAOg5d0nOlbRa1oa+CUrbGUKO9\n",
"s1K1K60LxAZlk8ZQWiHU0UUuQDnHAAyjelIcwOj4NipQdTlRBT+HrLVCVEK5smCT4WEyhlST21vf\n",
"pS9QIx6rrJJt1ZwRk3fLMy3lh+GbSU8p/deKiRgvPKu2y5xljT8HokdUfoJBN0b+9AYNdPwZxzfv\n",
"wRj3rjB+XbCQdH7rLOmVBWtc7YBBcmnLfJ50Xx9vsPrIGyT/orCu88gDS7Q97WNMWaRoINuEV0SN\n",
"7lASQ8YC8xeRAAAByEGbEUmoQWiZTAgj//61KoAGg+KazAhO48Rk+mELCfGa3jedcL7j4gDd4k3m\n",
"hfDQA786lCeWa51/s1J2qe/kkvnBjg4L/5tqqnPuWzD5CtqsuCrBZfD9tieYn0V6h2QRjHTgf2S7\n",
"KbBJVduRkgXz0DCyLCsDRdQx7ZVeilFNQPYHPpL3dFbV2ZQLhZ15DCVv0ijUbfdtbaCxQWk4hFwi\n",
"4Cl7Vcv5eumMKNjbBf29eX+p4vfxRMeLxQVGLH+o2FLpf2SZwh6nFX8ReHwFB2aNAZojees14KLO\n",
"dDXVOKLwRfawG/F4iTHLNjIHr9KJ7RMP+ZW2v4UodTEwj2IkfoeugjPYygxsYBEN/HIWo7Lp4BiH\n",
"W+sGNW6nzMrLHeZnfPrIXJzjKMZ2dMe3r2TPoxLKTVgPHlFgXbB9gOVEkvjr1YtxEt3sHivjr7TH\n",
"zrmzrXSS01xk914HSqt/CnYSKPxa2MF69g9I/BNJSHdHCdNGwRVm5U4w/DYDySkJOTHhPK5xLTdI\n",
"6pomON2J7Snu3IFO1cMuZQAgHAwoynkWURtTVoyQbA1o0XW4HcVte0xmLSUrxW27KPhiReLpDIah\n",
"P07+6UwIug2Iw2yxWwAAAP1Bny9FESw7/wAZUxOT3tiejYgyJDRrCYHaMUHhX+buBbaoqZ/1iUWs\n",
"Jb7slI/imiQ6OnWj09SEskbfc/zlMQQ4SNXZauWfHJ95XYh7wMFGgh1p51IG9qMewyJwQS444Zn2\n",
"viLgUg5+yrpXHCf0t8/9jDlbqwjDulbT62pdxpAyxuynsO8RFT3dUKeSE5htp/jbraDowEdpXZyE\n",
"hG0WYkl+RbztI/PQNZCwZsz+nvpxvKr5XHM1hBpXHcYTolc3yg25EknXG5iovx0Y9EuSqthrt+Xw\n",
"mK43mYVJUVC/Oh8GeZYMuS8/kSjScKjb9J2cbfyAxgmK23G/LX345QQtAAAA2AGfTnRDfwAc/TTk\n",
"s3FNYSmNHdPgDfXQC1GBEwJGCqSU6MsmeFhDrrArJ4DXkS7h5Olwl5LsAdAjNSMWnsyuwfwlhiS4\n",
"Iu9nXiMR2gsFQTdJfxAGWv/oGKrfOpY9OM+oH5mmAEYRbo0uYIZjYyyv9H1tg0RX725ktocEeT9I\n",
"3B3Tp4qYCOAxN7JPiw1LGqnL098ntFu5ng1+yPoA7ayjGtnhqUNzDdxHw06qdCQZykRFXaAS2mFv\n",
"lmomA2wH7gnlU4hH+9/QtYxMog0PKOypGE94HJSUfoT7gAAAAEEBn1BqQ38AHE7WHA5VnN1RP/m4\n",
"B17wBGTsyVXKs9N7WlI9AxsJJ7v9zVkMjf6pvv+Cg6JoQ3BLOK7r3bcONYUtZQAAAddBm1VJqEFs\n",
"mUwII//+tSqABlJow5npTNmtYD16z8AGI7v0s/GnfyqOWKggEMwd90EmHsgCWksYKFE4Qru8Yv50\n",
"LqOKJvWMLHGzKIf1mWoops1hD8q4hCLJMEdRItKEcO/AvOw75DCgogAQMHz94YdBlV1FB7/3PGw/\n",
"kvp11c7Zd3bjgbTV5f9wCrj5V98Wrk1QkXKTao3xn1WeAORpyCtFJo3KIIzvry0ktsvXmShsZdHK\n",
"SF2Q6qY6Id0i1QRrrPRdF2iq2m2rhv1eY7FLgTuR+kimJsshiQFr/qQ4tOO2msQRBI4huY4JSA+L\n",
"KftHgweMeBwJfCg9ocoILqar/ZxuCC1Kx59hrQRJPfm8amRIkwU/k+wKJNYh9fLLSBsxlrg4XoMn\n",
"PzXBXS36HS/Vq/PUU0Saj0Ks8oGCHCVcz3eoIxgiU+QJY/DixHlF4+MYR1JrL+dYLi5XU6rOa8uy\n",
"cymZbC8fCrT8nFmCuYcD3DNSzmKt2Ypk8ahqcNxMHCCE377w4QcAAK8hLicCDiuo9KVio6ugqDQM\n",
"DiWya9QmBn0ClIbSCznyVdfSZyODo1gjrJ9IiCMcnWI45hcgB0F/w3f4fUDX3TFD/vbMoTmxwMKV\n",
"hWEq4XvI4IEAAAE5QZ9zRRUsO/8AFKVUcHl/E43Gt6o4RZvBs+iAp/X/n7d7Pz7RdmO0J7CPEDVr\n",
"YOGCwg4aa5sRnK1DwPx5sIYzP38566ezpK1+yb8tpnK38Otysb+fPORXq89pSQ+5zLmadq08PRPq\n",
"ft5b+CuHdsaohxgMdfr5HBiNNodd0VK8TNpXmgIXzYR5RpK7ScM1kMS9Nv/EnJHMV/HrvGwgTDTj\n",
"k64XWbP6seQRZKb98opQD+okWzwHsAFj5ehr/ekl0IlB4NOOkEs2vqjJoc0vIcwkba8FSFkLe2wm\n",
"HNG8c/q9E5Tipy3avrHlLTvT0bjPkjeD4HLfC3isImW2RvjzyyF2TiLuxINvE8y7u04RbyNnhNhC\n",
"J15BQDsVja0XtFDfnnr/h18foOkLRpLJ1yQTMBboYsOrVzSZ9GDWwAAAAM0Bn5J0Q38AHQXz6rvN\n",
"uarixND043ZCNdAAIHUCWbOjp5TUpZdEciERk/s2Hj36k/1QHuy5AO7bU6FcTtkwLNXpp4kEhhr2\n",
"pj14tuqcy7uq8XfveV+qzHFw516IWJuk3fnleTKVnyg4EmdGVkh8uUm8KAFIin8/UzurGkP5FXB1\n",
"JS0uIqtx2mbD94hCpeHMsXHXmWbW3GUD6bwQzUCwUdgGFWWOBIzHIH3jzzxIIZ0rnTzx6fd8zSRM\n",
"hMrhmhy9AElVESMBSl9RUVwHxFBAAAABSgGflGpDfwAhvaB1qIOto5yaJpOYSSkbksLCkPuZStd4\n",
"LeT7CV/DcB+jLm/y8AhlFfeod4crFEXxelJR/fWiWC5cEAQJB3xoICKkbqYOm6EmFwfhOJrnHL3F\n",
"i7egoJ4YJywxTcfWExKLj/7q5Qta5s9pQnji3v49xEhquy1bNbsP/0r8degDcM/eCvveCCuWJP4W\n",
"kmgZOsTL6w2RcANA9FiGFsZYFgwwIJNSoi5uPhHUWhw8DgpZUJJwhbcwAlrJ/XkpDgMQdv8+KTaK\n",
"5RNrXWUI+DQboZuQqh0EP6Ucm1iy8BiBubHVtPfvfM6aTMlQH2sGDo7kxk+QnIaS5zzgTFrv32D9\n",
"yKVtBoqoPJ0AuZgM4FsUTuUjy7Mb8fU+FNoSPESiOFS3CYbvMWBzWtiplx16c8G+2sTGiL+yia5h\n",
"U5UjqF9tl+DCrXkPmQAAAhVBm5lJqEFsmUwII//+tSqABlvipo+ln6jP3YEZZAIeN2gdAdBG93Am\n",
"88+PBAP+pBG1b08i0fIFrYTfZkz4SYTuxIQ1JlthBpef+blJppNwqif1piWVs/t6bCj9Z+mNxSeq\n",
"fY1/wgLfvSZhz+cH951YQ+3lZMxDj+AnlpOYgaA5ONYw7fbC4eXvAp07e1QLTwt7AKsxs6j/dp/S\n",
"ROqifCEiS8aS31tyrNd0WUbq8QssOlpj1+9+m64Uuc7+f7EFYNlp0SQRRU2ux+5kBFuUthOQf/99\n",
"ODAIvGEvExgFy7U9xycg96i+XWorpOkUsmc8UuZbMVhIEf4MYVuxmTzjhiOVDlxwcksj2gNb3xa2\n",
"pmXlh1zp/jlUP6lnJbCcR5jJhGaBJ/wuH3P+rOiJDpAwjSIE4agxxO9XGnmQRqhYjiBkbby/Qs/C\n",
"0p6IlpvwhBITpwXRBm1mH+MtJEskEccmYaNT1YNO6b966q1ndwWmG4wqG8yXMOLAMIGnxTjTIpRG\n",
"9a5Z9Xdl+HR4ndQhvFfQ+mQNsGUdDPAaOtDr9NfsDESdrHz/VFsWMxlbozv6ME9/FBsTE8SLTZxK\n",
"uKA7LtdEmFdsikvrVwkDRWs6mlddIWSLEJey878D400I9Bm2F1YzYF8hIer8urpKTRWH3dl5Pnql\n",
"OkpPyvm3RplNwN8DaGYvFB3ajEHHx79ej7jTTF7j2dZAVPOuzAAAAQNBn7dFFSw7/wAYtYg8t2YJ\n",
"aBl5mT7LoVquTMWPsAY8JEk7n2Ltj2VU9Y6yhnUjGblNmyV5I1tDP1WCa31R20KBx8ZAPYjEjgAl\n",
"IBPsF6gwEF1mGQPgwIt+DQ7Ltrn+WWljoOZe6qmL3ODaEJKUCy9wZy8Qi5WMsDYzpEybVU1vipuE\n",
"rsjD5epFom/S3CRpP+JRc2SuBGV9X135AtKz2dAbEFqb0f/DUfvRpyE/xar90tpMsUisBmDyfPqC\n",
"QCIWsyVA62u0XX4SHuuo3VkmdASLaLWJS0hWsThucD2h8t0xx4j3t8tQeFkAoX+vhWm72BA6IAOh\n",
"cP5AynBLYvgLjkBSaw6ZAAABWgGf1nRDfwAgt5i6arm7oDsF+i9EHiOJ6m6rVkYAHTQbG9yseMuo\n",
"2+jJx58xpeovc881Wv+6nIPwZiRTONb2IQaBwPwYP/UAnKjoweUWtNn8yjj61Yi1F5n9oYReT9vo\n",
"YNykd6+UIhqXBR69VB8JEqms6DNcB++Z+7S8cRY1PTjUFRAm3tXpZtcqOC46Yje8Z3mZdWtke57d\n",
"wfIWf/bjH+PQoHPWtMGigrlGqEUElC6TETXz+nB7X3pF40yVazdjxa5pCPS8j1Bqo/RmILtftGxN\n",
"Yu+1c8QTzG5+3qHYIB5lZeEW8bNhQmHlV1zck8pKhAWM+UMUo8Yo1gMDIjGuUuNGCTYOoVand7oO\n",
"JxBESUm+840sI50gEtqO5mhNaTQVfGrhYgQvynil8I63rBmEOncCHtkN57Vx9gduQDjk6aOyO6bY\n",
"qsBt2jiwg3SW9pmMOjEKBDS6IfMiAxcAAAD/AZ/YakN/ACK6K1xrl4Eswd4/m5m3eDoe6aKYRGzt\n",
"qScyJrEz0/YMsioeM46osJc2N8un8CXkVjpps6zgsf8LlkG70ab3ccrB+um/wXzisesiYCwJDgAm\n",
"D8ODYrLA2f4XQyaEvxMLwdPggFdV9SLGW7IaDs1Gj2MKL95CD69ggFd4PlXdr+MMXaKnRfCfYej6\n",
"jyRkJ6YHIJryGsscniQRwJ0d+J+1KTOriJZQomY6moOkqhpxON7UIyt9lzU6HlHOyQJ+oRH5iOIM\n",
"+hKNz7H8znQxxv6dKCBY67rZbPlwYKywoLx2OIjAEQohlh7LdbGhKMy/zzEiJYFobhp2mH1gAAAB\n",
"WkGb3UmoQWyZTAgj//61KoAGC/pGgJ9CubE/Hy/U90CEEMEEbF2Q4cnB3oAeksXBYLQl6DX56J1l\n",
"w/mHq8WxaGt2MnAvQ41YNYO39iE6FvpuFKpW712yS65PLr83LJiqo7HZlMfRzKZN59Hb83g9Yzjb\n",
"LItfty44d54BI12++V5xh28HT7V7r0Y3bFC5OovybNWx1HQWDmvmM+uWQT6BKmA1pblkm0jWUuJ0\n",
"KAyepKH6sPnyIzz9TF/cTcVBDLcJ0ebq4QoNf0i/efDFq1nH+LtoZFDiLpeCwZkCLTOE+JMjcVxC\n",
"aWP/XfyRHhNANFDKtoVePLPasXuBVFa5xCh3bB99SWFmaQdxLlk9zHTMNOyCWoiRa9OkdBShrOe1\n",
"dfGrU6t4YEao5nNo7umRhNJMptOYWcUtCbSBQmV/4G3c/zgmpJb1N+5bNROg3nNApsFhNWPnDxXX\n",
"YEcAkKEAAADvQZ/7RRUsO/8AGBSepWN8xnNsxE4oE6H3s58lr1m+iqw+EfUFRD+Jna0+Uvzz41Eu\n",
"ATVBokoBIC1dZOqsBeTj8Ij9FIuxNitjsFqDL+DuZwvmGihDa0HIS79MTSVw/f89Ulk3p2M2jbij\n",
"TpCkIItiAXbWCZspatvMx2+GoOmu0/Pjqc6iwrXWXyi9/N9Jj+yY/ClUEyj7sTv82Y9nVf++GCrf\n",
"1w5ltOrH9rRQKpUQaVxp4gxcgxC4qFFOgMxs83r/WkZSqY9kO/9UmmCqExD/ljnRMUJvxp8FxL1d\n",
"H7PGv4WLI5AeltB+MOGIOr9NYMAAAADwAZ4adEN/ACG6NY+qIzQfcYKCb0AhP1JJtQboSZcB2Ux6\n",
"0kAZypUjTcd/OmJjJuZBZL4W6I8Qwzms0HJLp8KRrHdk5GfU6sWQ2Z+fhfAzgzC1XgPD4QBqkDkc\n",
"T0sPX8iasgf4/DARkJP486Pq1cqH5kOYBwnnR907+n/qb/xaeHwouVk6h00s/qlqepq0S1p/xGR/\n",
"GdINVBgCemrU+PPAyI+EQBjfU66sma3ahiVaLQtsD7mxr/vZVvwLqa7Chr1J9NZveiHKnAzIMG16\n",
"G9Gmkk/8FUHgdrIbZ2heuBDh1KQSBCztE11k+ocodRJkiMj5AAABBQGeHGpDfwAhujWPq8KUOIXq\n",
"Yi8pfsfzwlVQDEG6igccpABq5mcqZlBxZf6f05WsPP5oiGUHFHfSykAR60y9PVPsKziKYov/dHwR\n",
"Kft2Arvz4qT56TCewQ06i1++DP3k7arAvxqk9+C83xiDX/XWrTHQ1+jT9fNei76g+LJLvs+Z4UVk\n",
"oEaQ3c6fXvOR9+Md7sWQeZnYPXpC/0w6s38iG8bM/+n0jsTdTFeBwE6YfrCAsv/ybSEXYS5eoPM3\n",
"f/HRzfWrUb9MZw2WEuoxs0K4qVyNiDTxcyb1DdadbkuzwkaFG7T2ZM6Pebp0YyXRqckmxx6YTGzB\n",
"LlKwKmWHeooj6Lm9LlzVgQAAAaFBmh9JqEFsmUwUTBH//rUqgAYrWZggqZs1s6MH6FUT684nhne8\n",
"ykZKf89h+0voVegpTcVlgsFoS6xwNTcMDCv9PiwISM3bG5gmdpPxwsd2af4u9VMbVGyE78HSQ5M/\n",
"nbkySYm5CPjed6c1fzFNEjUv+hlxYNfv3cPYnGT/Yav/5erFhxatniKB++1xw2wwwm3hwteUjAt3\n",
"Bi79ySg16ijYqJM5fa8+vosVJZysXRlnbW7/ITdmkkl3c8ndruo8FzJ7m8m8z0kOYciXI4QIL6Xh\n",
"qroOcvOVcWB7Uug78ZH3AowGQXzMbzVMrLD5Q7gJi2vHbYwWBG8EpVzYFtaj2m+v5trtiq/wJKtt\n",
"WosqXvVBFnxrWYQFjXg41D/ASyQHPzn2WsqemfWG6/EDepgeax6MAFQfxyDScuq3fNmr8jf0net2\n",
"tjnK9AbUeZfaZDCLHpnptMZuk8clMx5Y+UVSA4sRK6q5yL86vVu3TWQ+TGs9ZFdT4m8kNBPSkwSz\n",
"rQpsGSml5JPzqe84pJi6yJhqfYRsb2q5mJ8tkrUntJCF8lR106wAAACuAZ4+akN/AB1RsSI82HuA\n",
"EDVZr5mUHFl/p/ZTcmoRWj4TfRvTsYw8OlDJB7dvZ/vcXyur4LGUumPqBQUBQHfGq57+bI/8tRzs\n",
"Z+nHU7WH8qJ9BM8/NBixjH12m2oVcRb4XvfrX32V+Y0hU+0j88MNPEcdX4rv7aeeep8jA96PadWJ\n",
"mSmtmcZfJIFp4fz7nGsOeHvsRUbV0MKDUYmKN+mrh03bThLfJGXI3U9Tnh+UAAABmUGaI0nhClJl\n",
"MCCP//61KoAFm+ceSLbmAtKM+jG0tYuAZBSWLg59auQBOS8BoT1gHMsjZkIU234iG6WAeSbLJEu0\n",
"KCLhFA+AqaJQGzw142KKgdSAFtORqvq8YepvegTzCCnS1DU11oB/GUVDtDnboQEryLd0x6NUSSMN\n",
"cECL9Mzb9QebAeTbVcgtE4xPKr7FEgVH4vbNIioC6rYN5svm+n7fErwoxd1c4B0MbzpTJ9ypWCIt\n",
"jDqP/6ecCXKe8Ac6gqcpyPRaKmFcKdx7byHCFs3Y36UHxsmpasB5iKonQtfou1T7ViPEDD+TNshw\n",
"6ncI9FQOyx3EYxNs7CdmXQjjuiQ/hVztgan/8HWeS5jp2zgzBv5BXUEnWn+A7+FBONSn2LL/uQ/w\n",
"xRZTcRa0x52ow/V5cvgKu7FATp/RCkX/G+w1Qnp+0VyZbVkCutQ1yOnQYxf79Uw65C1zWPQdQMP/\n",
"K+VS6vPAs27IKeqUeSeiBKHv/3isIgE+rjxQbN9Lh1YW9R/9r++mSeHrs60NzUtdlXFG/VIZkaKd\n",
"XMkAAADXQZ5BRTRMO/8AFlm8HmElw5CLBq61UEezfOfwLuaBDj371pFQE2TaGfrDL2cPvWN1QZqb\n",
"tmH36IVd+buOk4nAS7OK6LGtZWekVP+ro0ezqUL6LNjplSKI15AkcuTQweCsbYhrSLoTsRiawYgs\n",
"mv975sfbTCY9L8bxROvDNcwG30R1+JWvK+o/hwf/xA32LhBb08HGKIsZFejSCR/ZACyPMiASYPKQ\n",
"KnKHiabUDVxwGq+/saT475SIsPn2KAHPd1oy/JYI5la+DZBAp1lqCWQj4yUkciIB5BAAAABzAZ5g\n",
"dEN/AB8V9DqLglnogAnlbAbcaeEM/+Dr1d94BLu23/b924ZA1vKLZ+NWO2PdXQ6go3Sf7NA4nwhe\n",
"Jfk07l2+PnIu+kI9sd8bYLUmTTByKGfoyEUnQqTPIf5dfjB+AgnVTc5y8pWcKU354gRsJCt4lQAA\n",
"AO0BnmJqQ38AHxX0OouCWHEND0XeNAIAEOFUWlDAA6yKdnA6h0XJ5AHh6k3PwK41LuRgTA6dFitc\n",
"eGcLOFImUAXmZeNXd8BBiP4Y7WDb/nj/8t7UR/ChuIYJmbMzvyMcttz9Od2nvufuLeTpnnGxlC5D\n",
"sKIQ4TiAF1Zf6Jjc46nP71VK4g2t6fmiQijizaslPXbGXByTezIrwT4YraOsiMH4GMwabs58JhIR\n",
"tYealSfNunZO0jU9FNwqBbfEknuQIRSATwmWr49+JU7MtkfWDJ9lAsDVu2W/43LTVqxccM6dY8NC\n",
"EBnYMhV6U9uYbKYAAAGwQZpnSahBaJlMCCP//rUqgAZTWZgI3NAzNytjReukCJhCqRIQrgVE5TFG\n",
"RpO1ZRhoAw39KCX0FTF/pEpCWlYTREK0RX8M+i/Zkz6IOh5zRR0GMJniH0SeRA8U+ZBIRrL9Hl62\n",
"8kZwKv6q5Netv/8gTYt8wrrWIwWANbXHJaruY4G39urxvB/yx7ozBV54M/wmK8P5AgF0ljjPQAUZ\n",
"DnLEHwmopi3rWM++lGz+7pSmghGU/3PNF3AxzoRutm1cdRdLqAFKdPRrKeDtflDHW39dHMmsizA0\n",
"JAD4HEW4vO3o1CbLX2IxlZFPJGuT1QOtzPR7lO7pJCxfeGJXFchlosXXXbYjZoXRMBBKcHqbIWa+\n",
"lcjl1FcSEXbk84/WCNR/hEiDPBQ56Zc4Yg/Uu5te5H7B3WBkQkc5+tttienjQao2TkWT/tLarBIb\n",
"fSMA+83k8gbv1oyeFIIWqR6ZYarMVbzfFtnH/fWhWkYB/el6Kk3P0OPSTUOVwdEnhQ/ztu0l8Ij9\n",
"PRLg28jDAaygyMt+MtthW/hM1h+aETPrMcrgZoJoV2dKCm8mLdDu/CmksDfLJBRBAAABQkGehUUR\n",
"LDv/ABi1i6Ag4bMBZUwXqVJnyx2PYc2F7FCjvy82YHTp5//HJrbZhCcYERymRfl1ah1T5z9noaM6\n",
"FqCYiKh/nb1NKcv6lay4yu1An9EGWzEXMRaTXWcwehWRMZky6GX2Elv0mAOhcWIk8WVG2FWKKMhd\n",
"27a8KH0mx5CnVDu76Igw2moc1+yPfDPZnRGymeVWDMSj1/TY3hGgb5hmSfANHPp4nyrFETtH62Dy\n",
"FIZnfZ2tua96PI/858zqXLfYaSaEy66elRjPHGSUQ+kLj7sT6e2TgQoh23asg1dvl0lw6aW2KtOQ\n",
"yQVjdxBZzehiTDj2VDDo/FI5LuGH/jfe71B2giPdfSUEN0GwZPmh+oBJ3YPtBDdEXjvqGtPnj9YN\n",
"o2RsGDqkSW3oa8BY1cptmQPEHp1SMBrX83w6xtQW5X0AAAD0AZ6kdEN/ACG9oBtcoOCFYVPj9Yn2\n",
"v/zfoFr4rWL2j9A7ZlqQHr0ZVpbLuAQJB33EyTSBNnFvVuljxMl3V6GA7Dl0BClPwL31OrTpG1l7\n",
"a7ghzL0atyS5ApCJWtp2wOBNzezTQ3N+Y1tH+luIT/i1PP0KLgniqnzZyMrwKfZeXoYEIl7twi0H\n",
"PJVeAcAdd8vPtJ2LywfKZ3u1S3on0S/4f7cj446r85qt7SkU/lr6c/+gK5erYXiPq/kf9oXoMNwY\n",
"9h0XgCkkY0ibuAMW3BGf/tJy6AGuO11Q5hQVr9nNkIcjB8Plen8B0nqwKQkOaIEp5QYqYQAAAQkB\n",
"nqZqQ38AIr2gIG4zhxx5qcIQ9c2Osw5+uNtUP7c8wH627Nk93kOS5kJwZOUsa/GuB8LSJPcgk4rv\n",
"NNy4X5Kv65LRXZpkjxKOzss2V4BAkHf3fdjwk53/8IYs8s8oIvwVKvgR9wljv8Ag07Nf+XJo681q\n",
"NbSzOUK6bv18ql/byQhgzEpF9gyeKzBYpIes4Jq5ygJqsHenGCQnuZZGCejK/v7YZig/zrXj2vhG\n",
"gCib7VW/rlAZYnZRYtYW6jN8+34R58oAelpNik7qpp/KkHdSQspzMHjVSAa9yHgI/KVEUfAeaSTC\n",
"N1Z3u1GIF1TdZRU1zNyC6xbuAxPXtz6Ez91WiAF1zBDEIltBAAABt0Gaq0moQWyZTAgj//61KoAG\n",
"e1mYdETW3g4OxfplN37UKMHTaFqDxb+9ytAjpKDc3XnMw/MxT04D0MH+PToJ4KWEuN7AocErZRv2\n",
"Rz2GQBbpS8lS31542pk6xM8YYh0/yeF1AnMnBxO2+HilOPhojFg3EW0klIcf/AybMYAo9NSuBD9C\n",
"s4e75EU0t8atdvYkg/yfik+FMNyFYTUg/mi4EKL8VgLWVSi8mxQ1+/EWE53/+fwb7K+j+527pMW9\n",
"VCj1B/8oEXG8oxyHRw/TQGPoBS7lGz9zLwh8gXusGZBvY9Xy0pnRdJKDkZLO/YjZFLNiCRPsHTqL\n",
"i2GYmJ9itG9pRnevDN9cAKQP0fgHBe/nvlXFVK7JMen+RKub1gCuPtFfO/y6rA2fstwepz1bap4Z\n",
"wJXzTLHNbeZ6/jnjul1UTQDo+Wyv2+WNy23qAxLYAQV2nquSCySITwJSTVvg+SdePIAmj5UPClGF\n",
"OrJIf0RX1xfSrhrpF0W0EhW8ceypgG4+dXb+bPwXKBwbO3GymyW89X2WJwubd13etWWTwju8K204\n",
"+w8LWTwxqMyJaP52mExMi4W5Yjr9AyAAAAElQZ7JRRUsO/8AGBMzRwWmJT1LkI6j4SrVrkQOkpGr\n",
"7qmVB6agtU/P7NMI3vz5LIs62lee9zlMDhLgStRXRkKeHaPAGaY9hwFwZg4RZnlEijsKiC6r+GA3\n",
"jOJMGPR2G+iEvFq9JqYdk0b1d9ABTX/7oiMKav8zTfVNhhkqe32oj6u1ioYXU2U/9Y4cH3f/N9Gx\n",
"JhjbFALTGuJMdeB2a/pmxPSRSx2DhwUwXe3BT4iK5IJF2QdQUjRydlTK56i3AOElSAfT6NVqnLr8\n",
"mfbO/AiWtC7ZCdSKqLQrBheoCisxuwRDc+0Qj4IlPLBawyneGpiLaece3KMzpKTos+5YxlSYlKtg\n",
"/Me6PG+fH2sUI9B09T2Px/9ucFTXTUC5j4ELLv01D5MY2VAAAADfAZ7odEN/ACK9oCBuM4G+uLMW\n",
"L2dP1lfTvDhmlpluM7IE4yEUJKicqu4KM5OijIBGmwd/fv/FYUE8C16mNefQ0Uy/D+0+Hpx1ZFAP\n",
"3vl+5XYGW/hV3tVz6fpDmClx2VYPTKI+QsHyxc+qQa6raGV2rQAFnERDWDAoPELDpD0DBzrtQ9Gj\n",
"f1X0zbjtJNpqrwp/hRbaIrr15pQNp8wHXKVl3vyz9d+FD2rUtkJQVzj6V7XpNVWdz4mpDYH1JRGS\n",
"i2MURr0RotwXgP3Qnz/8L/EyxM0Sb/CNWw8xQFPmbCgpDwAAAOUBnupqQ38AH77opN4Quy1TZxAA\n",
"Og5d0nOlbRa1c67qPfhIW7P+8Av3GtFE0HFQCvcwO1xKybwlnguY0Nqo5bzwqVZ4m1UebapfH7JG\n",
"d9M94gSTzLBzp+7XrhnquJ9dwfh5fBCyLWBt8xSfTcJZr1HXGrAMOw+Jv+pCMMogCsMVlWbHeQuT\n",
"mD3/yuQp5lDob+9AYNdyDEIT/fV+2vxg/LuQxTIX08ne1pWMu28zMsHEcHxols+2LTEYzIWCi8BU\n",
"K3ZtJRE3rAjZxLOQ4w3m2m/D157HitClmlKcP9jJchoyWV95Jy2gAAABu0Ga70moQWyZTAgj//61\n",
"KoAGg+KazAhO48Rk+mELCfGa3jedcL7j4i4wMKqReszSNQj5h17BpSVMT9hX+zPhBrSs6Vj7HyaE\n",
"qm6lvw7kPbwwNhW67XEllpB7/AB7Dtmc/Lsrl2N4BzMZzIFVEJCqVkWDwHz0DCyLCsDRdQx8uGEg\n",
"Ikolt9wM9AgzvQ7TxR98jTrIYP8SP9CCVhDDASOwwiUKcH0pWRrgAYwjw8Gf7OlbogYj/no1BpFx\n",
"lYglvem+TH822s9SIsjJ3EA1IN/sTGSWgAXqwMREDl6rGx1E4un7krghrGWUm+/7j4jDoGqrYrQI\n",
"g7E+ktnqOLNELPNyQd8WQ/umSuXC1xL1umwA8X5+yPqMMHEIeQL1fzz/JWAXyMH93QMSzGumbhKw\n",
"Zwg0U+25Tvu4PnK5VQHbV0zvOU2Pj+MGf/nsDxqxrqZsD9S4YY9rcTfMxz/MkkzIgfRGQF/OgLHr\n",
"joIjF7P6XCeWe+XUgCwqZQG68PRNzfXkn+zUJpMMk0jjnoYnDkQ975Dz0Z65i4o7OdZtwLEOfaoE\n",
"pB0fo5td4PyA9vYIFlRo3xi7uvrQcih7/M7KbZFgAAAA9kGfDUUVLDv/ABlUeHLsmGHl+OQZEho1\n",
"hMDtEgrgr/N3AttUVM/7crMT5dwlm5uvzGVCn6w/p670sqgr5PJ6oiWC1npINQXp4CRzsctCmXzn\n",
"Ugai5K7NbwfaQcfbZKrjzT/10H2u4nhhcuuZyNqUHfbG94mETU3kKDy9A89Il0BA9I1A+R3yjNfc\n",
"+Nz5BwP3DN+ZYjka/GHLl0y68JgPyPoe9w8jyG5IXdu2vCa+LYvH9kU234z4psgT4qxlrdkhxxyP\n",
"UJXN8nPpx6cXDiQznv0L2owqy0csZbCzUw4CVJ98G+4T1R39bjI9WT0YHLigorskW6Eh4QAAAMYB\n",
"nyx0Q38AHP005LNxTWEpiZ1J9di26t3EruDGda0AVBouFN0G1ywEJMXJZuIMxrfHCac7PtwdnQsN\n",
"5ABPxruKApfvrd4v1WFO3Cl2Zd1SOG3/r1ORn6HwtueiSFcG0RNU2EL7iLFK3PfYpxwH299J2sER\n",
"9fENVpZ0Q3jjs6HsM0edV/QB07Ofn+R5vOS4TYLqhcaZAnuosw5RlS5g1Q8CuW9BZXMHWP4TGLry\n",
"nY5Y9ez3m8FrqVUEclyyvuywjGI3odTE+j8AAABPAZ8uakN/ABxO1hwOVZzdUT/5uAde8ARk7MlV\n",
"yrPTe1pSPQMTQCpdw5z/lBFmnGZwxWyqh+3IqkDkhpoxeW8ZCVdNB2x/1RnvvpDhcO3MwQAAAbZB\n",
"mzNJqEFsmUwII//+tSqABlJow5npTNmtYD16z8AGI7v0s/GnfyqOWOrIj7MzWLMA+5yFNFLu1hTu\n",
"dlbGlkD8jL3ONezhs0gurnHp2pFLsP3djo3BgKHcLr5q4kg5WMX28rT11jnIH4bHAuJDI0/Gub5+\n",
"542H8l9OurnbLu7ccDaau7k+AVcLYmIJfjhEaissSRpn2usY/14Z8WeJwbzUwclx5b0pufbMDj2m\n",
"E4jonmtfVQvsVKXSLVBGus9F0XUey7wsw1/Hxpa1Dj6X89JFMTZZDEgLc8SXNlb52uC+3SYuA3pO\n",
"yIZ3zYRDkwb5/sIpC9s/jtT+DR4JrFHAg/zOLQvdBHh2BZ/H88Qk1FOi1nkBwtogVwTsAvTRwaaM\n",
"L+Fy6Vw65xxtt2p06IrGo+vGB6Ev7rBsQ1lA5dJTwIES1/HSnI96cCqyJNRkq8io7XoKHq1jP8jJ\n",
"K8KCILcbnjTzWMILhY3EuZ8pRzEGblkg+ofcWDech+PkwDbk4flJvQ1eVGNBBbzkH58MbHNkp5C1\n",
"pRDfsnIb9VIwGZIgexRK5GP0EM8ZveKhcNpqg0C7EdFVGM7dDkwAAAFMQZ9RRRUsO/8AFKVU3AQX\n",
"TKYCKlUskM896ABcbpuBaq23+VbIBAleYM+Uh2fmC8hKxXufvA+Jyd8ERfcMKq2QBuOeaw8cG8nv\n",
"l00dW9FnZ2ewlISmCmZ99L0bw0GXPORXq89pSQ+5zLmGTJWLpbqXg/Gg/k26eFQ7yctp0OrjpANw\n",
"gpKfTmSwqfpdIyAO4i1HmWAczC/dxtyvK6EJns7ev/M+uhg/UBsLPdCc4ktjYaoFvgpYJl8v+SaB\n",
"iW6/qJFs8B7ABY+Xoa/3pJdDPx7Wo16RIr9F0VKx7gY2CroKhVZyesK3QK039pTJworswqeMoYtQ\n",
"SxUGWdIlnZAh/LxAqJSAgdbCea7vV7Jw7UJ3RZWLCaN03DO0g6FTEO0PNlB/y2w2d5hCS2yZtMLR\n",
"726poAjDu+5lgVHjodzIR1vHcKS57NpFhydymmBuCPgAAAD3AZ9wdEN/AB0F8+qoYAk/JkWPAABe\n",
"eS/K4R2z8W8rEZ4Es2dHO2B1xqZeWERk/2j9D35SD32hnizfkl5AQkKu7sKMRtxB0qUTg/5Ai8ci\n",
"ewPsEvh0cTnE+UnVVZQsy2FhpSkguxSgj2GzhV7H4B4oQdASRatW+4ge9XWWDwbNzKDfs2ikSZGn\n",
"ZK2J2cdk5ZNdF/NbhHS0c6vDp3S53pob/1OoP8UOX13YMuZJYtnSstfaINj9HWvrLOMusuMgy0ge\n",
"hr00WpqM4G4LNFMeeHMWs3VdDioqjp1BlI0pyKTUMl2eH+Urm0ENGx6u7gM90gDkOBdN7tgm4QAA\n",
"ASkBn3JqQ38AIb2gdaiDraOcmiaTmEkpG5LCwpD7mwoBhbYx9hK/huA/Rlz76MMOi96iXfBz3DSh\n",
"vG5XYVehGnggzBAkHfGgYDsO5F3SWLpvAiWuQYgw379rpdMwhqWoBgIHHe7UqoU3PiKCUX8CUwon\n",
"PUuq8JY4AYYztu7mmGelokJyoAJS97RU/X6H+RdsNNzitkC1d8I6jDPIy7qqN4tCnL3rY6Yesfv1\n",
"e8kTaN9S190RCoZyxCFd2JzsfgZhniY0nZmfUb/Ilr3HhSfAoNjT9YPJpZU0gCEN/XEjzBiwlPnv\n",
"oPqWZP16sXNdepP+5XR/WuewqnrAjpV8x4yn9rFVK/AamriL1xzzEUk66pD3JF3R2TNlp/oPgGf2\n",
"3Zht7rWDs3F41xpI2UAAAAHmQZt3SahBbJlMCCP//rUqgAZb4qaPpZ+oz92BGWQCHjdoHQHQRvdw\n",
"JuWMeCAf9SCNq3pRzo+QLWwm+zJnwkwndhEvWHQ/SujctvY5pe+lS1QEjQXzeizSF8k6tO14eAtl\n",
"F+Mync2FH/YIAKwBXgDqn6AXOHpWQcynHtaJryxWYm270/11pJpJLJP1UcyORiPI54DPlbzdu+l/\n",
"jiFd4hpdaoZTSIPUh6A6ClqPxEqekFrNjAxud2WiOSd4IE7Kaf//vpwZ0mh9bmck4Z3rAu3/6Cvy\n",
"KA3WyoqAFX4UT0ZjH4z6LrUYRBEZElMEZc4snCHRyZf+tjKnoDXWOrVFpzxu69dV7GJ+V1irRKox\n",
"Pd1LRXYUoYi+P14fumR2pYbtX+VBW+m+c7NAd8Z01d3TTKV7Mg7nTZdtCA/oFcETl7++5b2EIheP\n",
"k2Fg+5ToPyynpqzSsvv9vWMyfYTJnDg6PojbFsxSs0nRUvqnP5QCdr6QHBhWXFOG60F0RsLzEsNc\n",
"wpNcPfKeYjjdCfe8YUIVjq0PBSvcnC+B/ETQWaX7IFbWhPaknWILlx3KsiYwYSMVn5rwfQd4Jkdd\n",
"9H+fdht5f/EJHYCK5IGupAjPxHpu+QiB/iUSmCHkkTiMqsG8twzlljjsl22n8veAAAABCEGflUUV\n",
"LDv/ABi1iDy3ZgloGXmZPsuhVsylb+qqNi7GSIfQ+OHuoRwObuWCiDJsleSNbQz9VgmS3f493Q1l\n",
"fk0LSjQ0QBKQCe3UmCkV8vYYHcKN9CZn1L0i/3IstLHQcy91VMXucG0IQjYMvd5K4nw1TsRQ+zNt\n",
"c33OM7wT4gTiFbFnfUP6sORkbyxKD8+9VWHRCKkGnoAnjqhwkHV3YzaNKz290rB0XwxFDvsi8iqf\n",
"z+DNrf49LxpvDCniJY8b921MDAhjoaXQisEELwuIkEG2MG16iA+xn4KZIc8cifkUnLKYTAHTEosc\n",
"/geFGHZmG9d/0Ad4ehB1+UFj3eeT8gc12jWX2ySdSQAAAUIBn7R0Q38AHbXz6qhgDdTYSzAi1h3K\n",
"16Xr3JTVUajJdHP4n1zwK/61yxZ9pP4QSRtJbkJZWH6vivN5vckWYfjVoaQoNcq3qWx+bI+OTtrh\n",
"UNznJnNVmMngQpK+748FuR69zyCunCVVntkmuIrtQvOCVbqBuRz5Qxvz7t49H+VL6IAp+Rh2gf74\n",
"0j/UPUfosZ/ElbvCMu7rvOP7cWI+JN6KUOE+/AXQCyHGSkSvvSc5FsX0fFal2fQXaEkH67EHfCc5\n",
"xhdseiByl+PiqAs8A9zuy4qmXDeeIj+3Yojnw30fZXbmjymzKitBenCylofDP0QjYedpgwNVFWxv\n",
"pKDrpf57i5C5JHBxrkMOZNs3TkoKjfQLvKDT/j1Fvw02tHitRU1MR1mnPja0zhtM0e5b68dpKMZ6\n",
"9AO+761c+Ba/40Js4HhAAAABBwGftmpDfwAiuitca5eBLMHeP5uZuF9cX0/VXhqHcuiBABGdnZlB\n",
"vvbdh+1A3f4uQyVZizhw70/9zDh2nx3tQGn11M/7g3e0ETDcFJMpuy3pyqZj8OhCsFXcJg/Dg2Ky\n",
"wNn+F0Nd65xqPmrT4IAWVNyWgNuyHhWrg80hH2qe3n3QFTH+AG0t1LUQWRwdt8cDbAi+8IGZZrTn\n",
"QzKAGB5g+jkMrZS2t5af/14Dikh/TUO9x6vp3udUZwfEqX9x43nyKd2KkcrjEt0VxTQ1LHt4TKTU\n",
"ov9g2wymXIrIg/m2cGScMEoY8xa4E2v0IBu8Siv364Oh7cF3cjWG+ZJkZ6xGCUsmpmsJt4n9AAAB\n",
"cEGbu0moQWyZTAgj//61KoAGC/pGgJ9CubE/Hy/U90CEEMEEbF2P5yKT5EQsPLolJYuDn1q5ANTN\n",
"SJwpmVcvZVK2Tco4v2Comd7hwZPuuXhX+lvh+l6ZtjrC3czf1ZVbdumb3r3D/ioYe7qcFNf7aS5r\n",
"2YnlPFx/ox3Po4uR9L227Pa5JPu/JVHojzbyIvC2hUPLYoK3yo8EFTOEx9VW2Kka/dDqBAClQEXM\n",
"coaHOVrqvWOBlx0SmrR2Fn5qD0ttjA+wKyG9Ww/+/fxdGsIy8lThxbGnpYEDoqIDxAPPdyC1j/7C\n",
"x1S6SZ6cX8TWD+edELbCVScHr4twowGayNRkN1sGJ3ChzFZqefnm592USWq1KVPalCkn+IgAbkI0\n",
"gf8crEnxuQcz5L3ov1loEzryk4ptgt40vN/cUUrwi49uNdXDzDlba6ntBbOYIPKYQqVbRsWX//V3\n",
"7VjjZzb0fU2VitbTbNlERmPP5obsCvIRmiOfAAAA7EGf2UUVLDv/ABgUnqVjfMZzbMROTbEr98Ov\n",
"G6hTv8LwbEOVBTuoZFwTL9eOUuW51yt7Pk5XoOwvCITHjPxM0+ACPLC5p8LXGPLXOMFwxyKNAOm2\n",
"+bVnL7eC/eonqWYHV7ElnGiaPE4DZvhksvIAUMvT1hgYsLWg5pHxPTMEf4vPc7k/U4gx+qn0dLIb\n",
"xLE6WPqhOli4SJOCHhekKlwgxlnM6S8wIxjTrZQVP6tyjUXc7nRDpn5+4xHTB5JTQd/Y+v5uYYim\n",
"vSxL9Lp9+sJa/YqUqQ0UFcQR3Tlp/PCrTJ5gUcQmlTDSjEV8pdpwAAABAgGf+HRDfwAhujWPq7Ze\n",
"gCJPvLBRhSSbcG6El3BFXKqbl3V6+XLJCsWmxwO7Xskzh85D3/GGBbxCjXU3okqTeEYfyjkOl+SH\n",
"4VGFs6uGeBXI6FuyUdCktochZVIQW+D6bukSQtQ9xBoZWqRH4hlWFBiT6bV+GQGerlgKyeaNsqD5\n",
"s+IDfM/wce0dikHUV0++Nr2rHe3jcRRrSy2FHjFSMdnyldmaj1iFauYYGv6d3l/8LPJtc5g5u4Q0\n",
"WerxF6DQAN+WlQUAod5dWuqnUKOySujKDQh4Sh1bNoaribkhCngsbjiJUpnyDzJfWcRyF47YB87L\n",
"Omkfy8ijCTvweGsJYAgScQAAAQUBn/pqQ38AIbo1j6vClDiF6mIvKX7IDWIXdy1QyeJm7hwAhKrN\n",
"5ZQTH6lrtJ9D3xtslHyvy2ywnd5a5/owLJHRc2EtkPadJ8Uji+G9O7CT6ooBM3rAgAWaKgWADHof\n",
"Rk55HzZ+V8DMw4S4pnRLudTRFnX1DyLXHV3VXMnhAeP+ewFDtdkUHGMhcSI0U8KajX0wWNdBGeGb\n",
"D8Ns9BH8mxfhSu/SqyYkA2AIdaTRVyL0w7XOVFH3DXljVqrcwMdXPvGgiBcw6chMaLbepo7nSmh1\n",
"vAbwAQYruBhNTN0eawky0jofbme4HocI40c1sz31wjy2n2/uelK4XikXYFYmVtl4Kdutz8YAAAGb\n",
"QZv9SahBbJlMFEwR//61KoAGK1mYIKmbNbOjB+hVE+vOJ4Z3vMpGSn/PYftL6FXoKU3FZYLBaEus\n",
"cDU8hX8r/T4sCEjN2tKC+to/+IoDOzT/F3qpjao2Qnfg6SHJn87cmSTE3IR8bzvTmr+Ye4Ac/+hl\n",
"xYNmjmRG01XaPV08JLNnbV2zuL5cn/7CsR7I4pKAadGKE6UheVLfqn0i791ThTaaO2OCRjsSWF8e\n",
"1o7SXLcWHdmh1WCFSlfjet1S/FkIphxf8M1ZQjLPF96/W7wlOpiP6jEis8o6251YpmdqxS3VSmv/\n",
"s9Bv3ISLvkMspiZj+iQwr28MINay/7syEY2A7ZiKqNUJX069yti8CuYwd1gGvQZSlufV+auVaTNU\n",
"xocXs0XuFW0e/AWENf2i3yxrLFTHW9CCBeoKH21CafAHq6hi+H/e9DkZU77nSidgvmP6DIx/XjI4\n",
"Sp9anaBxYwcylzQtEH2XN+nrwpDPp45KYG9LI0xieadJ2QOTHIvADfNhP/PY2gqE0NQ2qkvQc0a7\n",
"Xw6JCi5LfZz745MNAAAA8QGeHGpDfwAdo0DVwAgarNdw1dyEo22Z+2voCmn3MepWOJpNH9uE22Fc\n",
"UAf4fo25DS3VGYdH0kZ3bYGxdzd+R7awrh1yiW2ItRU9+fbZ+7eJ43X/1GQK2tLeuYX+rXNnNYVn\n",
"3JiyKGKiuk48G4gEpBGTo6LBxeBZg0OXhUHfR3yB3h9X56ir+g4EbNusZoLNQh23BaGzc9/s1PO9\n",
"1PPSEqrUiAosSTAygJNCJGqMs5yCqcS+EZopY3ntHhRp/rTMQhL4aAxAb8XQkEJtEmWrzD4p1eX6\n",
"QEZh/6hTVX/Gz191R2H/Dtkpg79J3GkssFm0vPkAAAH+QZoBSeEKUmUwII///rUqgAWb5x2D2a6r\n",
"t0Z9OpYFG2tABdnWLgsFoSkhKeOGdpZQLTxZJNtdR1o3VEUaCsJe7TDcWLiNBjbFk4iCHCNTwP1B\n",
"ET8aIdy/mqBaPrTdtuT/6FMRex7yXV0X/b0t3IdDKZDeFLpQzjHVkdbvbm3BNwCciVQUNcJ7Sjbw\n",
"T4hbhPp0oEDMMYqhG0FXqi8cqsDNhwZenV4L974lIjS1k1BRVCVuxIwrhHZ+ZNeKQOVccqtyU7fb\n",
"1nmmkdbnAEav9V5tnQTxoYHQvrZLL4f7C+LE0IOtSnKggNbex2Xp0FNi9T/+fjTgmF5bW9OJ+WCx\n",
"leyLvNiQF8k0bwSPMh7702+7OB9yXypsT0VFN+3fNlolLg4yJ7ye2ijeDcs0TyR0KI9OqHHwk9VT\n",
"lv0R4DjKMuNtxv3yyDdQ02ld84rRe/IbVoqtujoBlwArv27SRkTybmrwQddynU1vfFNgJ2tkTxsX\n",
"EuhAyTUDk1pdyrePvO3Kyjq07E+ZdqW1unVDCL0p2PAM0Bdj+ozOm4QJPGRq3YEQjJpnk1BNx6E0\n",
"yZMxRvkyW2tYZosgoDR8rW5jEN/sH3PsICgk/jLYhgpsvFfXxjf0NPxMCt81bgYfKxBAoUrGuF/8\n",
"Gb453zLMx96NgDfHj/3/yVULmADuEWX3e7X8vwCYAAAA+kGeP0U0TDv/ABZZvB5hJcOQiwautVBH\n",
"s3zn8C7pn/fvWkU93yxomewKAdw+9VXghKzj8nMy4EQ6n26QhvOvN3ZOGl4wrl9GlrTzWwgssqXz\n",
"oLBd9XVA4LrC7D/kDb3CEAYvcHCWxuhsk3WHFeLlRhwB95RghbDR4boSp+CQz3CY9L8bxC9Ohf/r\n",
"dy9+xoLX1H7kyaZJ3YehTdM+5Wu6Hpc4XocPo/ogFns0WlfgVPekkiZdh228q3p+OFEAyCsprsbc\n",
"bh4x6zwYau0C11ECccZga0PS18ku4j08dAfMYirHksImmVD9Aw8yto6D9YLwntF8IaA+FPG9VagA\n",
"AAB+AZ5edEN/AB8T3aVQEVcYwT0kXXzzDP4yP2lC7bONTcb6acU9HQ87UdrkSLI4+OHKFlU0EAFz\n",
"P/GPhcZ5NOIVfnz6vsVd3DH3XZLg43PF1cMypwOcG8sbzfthjMA4FQSgVvJe40X2MhECJet9t2G/\n",
"XdWa+YBzkUuLdbRPBeGTAAABAAGeQGpDfwAfDtYNiNYeWLJ1JGi8AHLac8oZrJR5tDRFy80bn36g\n",
"01RfxVuWBDFeUQUU4VHoswV2zHbq6MzAloc0SM3f88f/qXApn5tj32GTO8MmdjG+5h2BlZLr7lVk\n",
"BcTdEueULRCVgGF4dFB9PX4Y3jYyGQfKH/BWnAEfbs4hEQ8ebrGB8mSRpcKz5q1oNG7pkp8qNfsq\n",
"nkhG1h5qVJ826dklpNvhQDQdQnVi0zusZWH7g9GItx1/0euTzo8U/z7D4DrbASMUmgB0DC8TSqJd\n",
"xZ+UMAYbubxMdW+iPv2N1tIKXHdcOVBHhDDt1MeY4rBQavQwdjpFZBiUMt5ya+AAAAH0QZpFSahB\n",
"aJlMCCH//qpVAAyk+dgiPwCMdRFSufgoxGSIR+/0rSe9Cp9hy8WpEfkfjpu1RSHWd3zlulcFC+Nh\n",
"XPR//hjTft5KlTxkfWUjrzSX8Q8sCzZTRHqzVvb/rscPsXHQf0E6taB/yJXWDm9ZR5fbjX3mwQRc\n",
"72p/7Nk/lJUO//4LM1qLgtlckFFvGA4aviZYHpBb9w1OJg/Jqwvkkixar7ua0LNG3ane8+4yu/5g\n",
"n8krsqxREhrpsaI39b317zkKj6KVaeKiNvQ1KBsts5QsX+yTO1tzmbv5PRxGS8tz2hKf4zB8fbWM\n",
"XhqB6Gi2mMVEo6jXnv5vErjT3e551EcovqLpcSnuFBTI4jT6V7ZqZq5zqsmn23ZqFTbBnXJfy5qg\n",
"Xc1RIbUSG7SAPcicWIbuNtZ4GQS+WKAEZUxr++6VPQD3gpW4BeKCxEy910wCA11VXaqCgcSgS5FA\n",
"dwACIPfrp0NhEyPCvA4qNFC9NitDM1I8HthEGAjfRL6imFuJfW4+Sk08ZcO8JNBK0/bkkNG7XFo7\n",
"Hs15nZek/o+FGsRiwki6FYqc1HBc8skTelrrFiYgicL9M/ehriAlP3GGSVQdD58oSyTAbR/XOwHh\n",
"/k7736bu5rnUg2SpAi/FdrWUFq0zx+C7UUDgbK+SgABs/nsA2PEAAAE5QZ5jRREsO/8AGLWLoCDh\n",
"swFlTBepUmfLHY9h6nZJebQZXCAk5QrW0LEqJOc6Tf3RfmBa+BH+trXpxDsoWsYBGGxFB6vHNSw7\n",
"QTuHxSINvJ7kINONdsnA7unyZfe+/dUQpBab4cd9DfyyBJrHeEf61R0Nfn0RkLu3bt6BWIYQlYtM\n",
"K9Nfs/vIPwJSfjpXcON5DPtNNDffXZk4RydlgN+S/E7EUmDtA6DaeTT9v6cz5zUd9DSGZ32drbmv\n",
"ejyP/MmN69TJZPy1fo/BndGgtSNNbFsKVeTDjxqdcz9cfjIrJ3P86/aSSTu++gY85cN7L+QFkn5k\n",
"/lX20+90kKxSs6X+x+u7me+jslyG1ZQaBGKwDx+RwViiPwDARZocg2yGxzRByDsEM59E93SHlUl9\n",
"GT+PqBiUfn848MoVbAAAAOoBnoJ0Q38AIb2gG1yg4IVhU+P1ifa//N+gWvitYvaP0DtmWpAevRlW\n",
"lsu4BAkHfcTJNIE2cW9WUS6DTP7xEfhthE/Au7/XTkrYH5bPnHuWMD+L4E2Ys7TDv/WnXsb8WMjs\n",
"GVKLefmxcZqtW10iMABVusPZiYCVoxR1g16JAWeZ7iIjTKxZ0g1yWUY7SYbSh6LLTrvWvhE7lU5U\n",
"CdpswEmIpPdhoFfYojayY1ypJuWbbU1PB5nvwD9t85tVUeFQcQm5aN4kQawNooLXHpvRUW63Gqd8\n",
"iY0WiZheEXu2JHmP8XM7t/dfyrk3Fx0AAAEdAZ6EakN/ACK9oCBuM4cceanCEPXT9ZV29ukUDhUK\n",
"Q43qY97tIKPQ4ZLk+xSOgxBfQxL7yIrZscfkKmKCSoYxQfZ+tSzvOZ1GhW2ifFuVzAIEg7+77ixc\n",
"Kx//CGLPLPJ464HVUHGkhcx37PQ+kbQrXlUbN3cWUp0Qf4LtEibFhZ+LpSZJ4udEDKi6Q/S18Psl\n",
"/qmdcccWROb1W4f/Xy9V+lMS0Du/XhxzsIhWccm/rlAZXG9J5NMLdRfS734QHwqLqFpe0KPTU/Mz\n",
"iY1ev2MPDzHxs95uiDK6gRc1gvD7TgXhVki57ReTigwP0Vcnsm9mMNHj3Nt6/RMhlMwCLQhy6qqL\n",
"YC7Z58RnNbEutfWZAa9Y2SYIcplB+x/e/c7TAAABlkGaiUmoQWyZTAh3//6plgAykehDX8oAigHL\n",
"uS7e5BiYpAhLP0Zp72qQ9WFfih2hD6ViubvwAAAy+5vuYY1yi1tJuPfBi/DL0xvClymIwqUp5EK2\n",
"pijOf291KPaqRN5kbJjB/2wfKr1+XMiKLX6DysREeFfQlDwQLBucvt+vNOXQokOSOb4yTYfCyIZ/\n",
"GHqmX89FI8GoC7SVJ8dqrGOCOpcjHfvSY2QsrqBh9dhAV5Sl9v/BQKeopbgb9Qoepn/uEMh2fyEW\n",
"JmX+JgRFJalJclAgIlVBNaF+FoinY0YPKhqMcuoH+rtaEk2LTWu4NHdn9ysTAkHlBR2G+58hU289\n",
"8X49s9CJy7d2oeKmsapTwnIxxJ2LNCm+TxMniHit0ZHqI5VMxQ+5ZJ2tPHM7/cT3gdae3yVR8+YM\n",
"/KU5H6oISvxSd8TybIcXMyYVHn6O+gwy4SKx3AkMYLFpKRIO1eI3ZmEPll+L/2Ahp3aDBQxulIlY\n",
"Qc1v4+BSAHSjYxY/VpZwrkFkWgmuXijX9pnceU+eCQb0BkKKYYEAAAE7QZ6nRRUsO/8AGBMrmVrk\n",
"4p6lyIABr6JcUvWGXYV0DKg9NQWqfn9mmEaqxk2L7hAoVLAefd2AT3uOnaK6MhbcdSJ0jbOgAdky\n",
"1NCtoTFYEK1L3oNAlJW78V3WE6NttmJ67HTQFhc7jbPt6n2fAdknrF4tehh2ttPPRj0ZMNDck2O/\n",
"Og/0bAxzaaL7DSYz/qGCfH6ue/8E9mejEEqzP8HffVv8Obhn2u8eQxOotWj4hO+DblITeYVYJXny\n",
"h4Mo9PoOPQCtWY4pEEbVZmokYfc6NrhoTMJC8d+WVfQUp/9dQN2FtoGBhQPHEwvVbIcYhR7B4iO2\n",
"lHuM7fr8Nz2PLRQOuR4Lhle59+tgw9IpLSJGfVu5u0NIKILKM/viNoDYYuKxIDdR/J6apnFKAoah\n",
"uk9v6if+0v3ru/qsdmBBAAAA3wGexnRDfwAivaAgbjOBvrizFgvOGL0/6vuXDLpZruFaiDwd2rdX\n",
"jHVzx9p+aFelpPZGVUpD9afD05Q+ygH73y/cGcCL/wq72iJds0hr5PUpNV/aSoB5zpjnS1krIC0g\n",
"xgvcsTNLJd1aFsq1w5umkQK05c9QgDPa1eUOrMmn+/YlpdytXE6u+4FAjIpYVgn74StUYfcT8IT8\n",
"SGX5Wru0UB/4BiwZwXDYz0r2pPySvTt1TUg57ubb0S/BqMvEVZ5rArNFw0GaRO5EmmTuHjFK31Ed\n",
"ZcrudMiOWUSCfSesj44AAADfAZ7IakN/AB++6KTdC0Gg2vR2G3QAHQcu6TnStota0MGq57eEms8e\n",
"GSZ8YTYymFLgl7YZGG1YXmh3orKEBl6b97W6tU9/+wsf9/cg00EpDLAMwmuhlqrl+tcaP161PaCT\n",
"db1JjfLZ6rQlIR/u8Lq+hDMPBrZgZ6lFmsHEDUzmL1vhrC/Eg5wjH+dLR3xJpn70Bg13IMQhP99X\n",
"7a/GD8u5DFMhlFEykeU8M0AF5LVwxauGljyJ2PG9wt/W7GNjLNgsX4aFTR897+cKWdUMsr13pC8x\n",
"KjWMpGHXcQ2lKSkGzAAAAR5BmspJqEFsmUwIb//+p4QAYn+ayCPJyJ7QOf/irXuB3I7yUvrv3Wd8\n",
"OLQaJBb/+EMR1r6SAeh0um3VtQPrwYoZU0zDlMzZlECRYSRYOAqgamI/sUVWVEYaYAVab8QpucQ/\n",
"sSTh0wVtYsFYYkt/gr7uhkEpx1NPSuJ9CqWeDhMsefol+oaGZkPTooDGiCB29X8Zubhk7s13xY5c\n",
"l2KWl6cdQs8QOBu4PKBLJa04v3ctO+FHUCNJTXN7J5YnaOHn+BLPFy7A6HoUxVmuK9kB/hB9j6ln\n",
"0nykP3r6vgXJiVxtga3Ek+Zj3edZUHSAUux6bbxkCgdvPWLgxmKM0iIQ0SZS+9McjsqW/5Kw1hL5\n",
"sobdDT0GsHJ+I+IDODn9/vmRAAAGqm1vb3YAAABsbXZoZAAAAAAAAAAAAAAAAAAAA+gAAB1MAAEA\n",
"AAEAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAA\n",
"AAAAAAAAAAAAAAAAAAAAAAAAAAIAAAXUdHJhawAAAFx0a2hkAAAAAwAAAAAAAAAAAAAAAQAAAAAA\n",
"AB1MAAAAAAAAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAQAAAAAGw\n",
"AAABIAAAAAAAJGVkdHMAAAAcZWxzdAAAAAAAAAABAAAdTAAACAAAAQAAAAAFTG1kaWEAAAAgbWRo\n",
"ZAAAAAAAAAAAAAAAAAAAKAAAASwAVcQAAAAAAC1oZGxyAAAAAAAAAAB2aWRlAAAAAAAAAAAAAAAA\n",
"VmlkZW9IYW5kbGVyAAAABPdtaW5mAAAAFHZtaGQAAAABAAAAAAAAAAAAAAAkZGluZgAAABxkcmVm\n",
"AAAAAAAAAAEAAAAMdXJsIAAAAAEAAAS3c3RibAAAALNzdHNkAAAAAAAAAAEAAACjYXZjMQAAAAAA\n",
"AAABAAAAAAAAAAAAAAAAAAAAAAGwASAASAAAAEgAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAA\n",
"AAAAAAAAAAAAAAAAABj//wAAADFhdmNDAWQAFf/hABhnZAAVrNlBsJaEAAADAAQAAAMAUDxYtlgB\n",
"AAZo6+PLIsAAAAAcdXVpZGtoQPJfJE/FujmlG88DI/MAAAAAAAAAGHN0dHMAAAAAAAAAAQAAAEsA\n",
"AAQAAAAAFHN0c3MAAAAAAAAAAQAAAAEAAAJgY3R0cwAAAAAAAABKAAAAAQAACAAAAAABAAAUAAAA\n",
"AAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABAAAAAAAgAABAAAAAABAAAMAAAAAAEAAAQAAAAA\n",
"AQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAAB\n",
"AAAEAAAAAAEAABQAAAAAAQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEA\n",
"AAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAAAwAAAAAAQAA\n",
"BAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABQAAAAAAQAACAAAAAABAAAA\n",
"AAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgA\n",
"AAAAAQAAAAAAAAABAAAEAAAAAAEAABQAAAAAAQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAA\n",
"AAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAA\n",
"AAEAAAwAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABQAAAAA\n",
"AQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAAB\n",
"AAAIAAAAABxzdHNjAAAAAAAAAAEAAAABAAAASwAAAAEAAAFAc3RzegAAAAAAAAAAAAAASwAABs8A\n",
"AAI/AAABMQAAAGEAAAD8AAABkwAAAMcAAAEbAAABNgAAALkAAAGdAAAA/QAAAMYAAADdAAABzAAA\n",
"AQEAAADcAAAARQAAAdsAAAE9AAAA0QAAAU4AAAIZAAABBwAAAV4AAAEDAAABXgAAAPMAAAD0AAAB\n",
"CQAAAaUAAACyAAABnQAAANsAAAB3AAAA8QAAAbQAAAFGAAAA+AAAAQ0AAAG7AAABKQAAAOMAAADp\n",
"AAABvwAAAPoAAADKAAAAUwAAAboAAAFQAAAA+wAAAS0AAAHqAAABDAAAAUYAAAELAAABdAAAAPAA\n",
"AAEGAAABCQAAAZ8AAAD1AAACAgAAAP4AAACCAAABBAAAAfgAAAE9AAAA7gAAASEAAAGaAAABPwAA\n",
"AOMAAADjAAABIgAAABRzdGNvAAAAAAAAAAEAAAAsAAAAYnVkdGEAAABabWV0YQAAAAAAAAAhaGRs\n",
"cgAAAAAAAAAAbWRpcmFwcGwAAAAAAAAAAAAAAAAtaWxzdAAAACWpdG9vAAAAHWRhdGEAAAABAAAA\n",
"AExhdmY1Ny44My4xMDA=\n",
"\"\u003e\n",
" Your browser does not support the video tag.\n",
"\u003c/video\u003e"
],
"text/plain": [
"\u003cIPython.core.display.HTML at 0x7f1286b190b8\u003e"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"import time\n",
"import traceback\n",
"import sys\n",
"\n",
"from matplotlib import pyplot as plt\n",
"from matplotlib import animation as anim\n",
"import numpy as np\n",
"from IPython import display\n",
"\n",
"\n",
"@tf.autograph.experimental.do_not_convert\n",
"def render(boards):\n",
" fig = plt.figure()\n",
"\n",
" ims = []\n",
" for b in boards:\n",
" im = plt.imshow(b, interpolation='none')\n",
" im.axes.get_xaxis().set_visible(False)\n",
" im.axes.get_yaxis().set_visible(False)\n",
" ims.append([im])\n",
"\n",
" try:\n",
" ani = anim.ArtistAnimation(\n",
" fig, ims, interval=100, blit=True, repeat_delay=5000)\n",
" plt.close()\n",
"\n",
" display.display(display.HTML(ani.to_html5_video()))\n",
" except RuntimeError:\n",
" print('Coult not render animation:')\n",
" traceback.print_exc()\n",
" return 1\n",
" return 0\n",
"\n",
"\n",
"def gol_episode(board):\n",
" new_board = tf.TensorArray(tf.int32, 0, dynamic_size=True)\n",
"\n",
" for i in tf.range(len(board)):\n",
" for j in tf.range(len(board[i])):\n",
" num_neighbors = tf.reduce_sum(\n",
" board[tf.maximum(i-1, 0):tf.minimum(i+2, len(board)),\n",
" tf.maximum(j-1, 0):tf.minimum(j+2, len(board[i]))]\n",
" ) - board[i][j]\n",
" \n",
" if num_neighbors == 2:\n",
" new_cell = board[i][j]\n",
" elif num_neighbors == 3:\n",
" new_cell = 1\n",
" else:\n",
" new_cell = 0\n",
" \n",
" new_board.append(new_cell)\n",
" final_board = new_board.stack()\n",
" final_board = tf.reshape(final_board, board.shape)\n",
" return final_board\n",
" \n",
"\n",
"@tf.function(experimental_autograph_options=(\n",
" tf.autograph.experimental.Feature.EQUALITY_OPERATORS,\n",
" tf.autograph.experimental.Feature.BUILTIN_FUNCTIONS,\n",
" tf.autograph.experimental.Feature.LISTS,\n",
" ))\n",
"def gol(initial_board):\n",
" board = initial_board\n",
" boards = tf.TensorArray(tf.int32, size=0, dynamic_size=True)\n",
"\n",
" i = 0\n",
" for i in tf.range(NUM_STEPS):\n",
" board = gol_episode(board)\n",
" boards.append(board)\n",
" boards = boards.stack()\n",
" tf.py_function(render, (boards,), (tf.int64,))\n",
" return i\n",
" \n",
"\n",
"# Gosper glider gun\n",
"# Adapted from http://www.cplusplus.com/forum/lounge/75168/\n",
"_ = 0\n",
"initial_board = tf.constant((\n",
" ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
" ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
" ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,1,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
" ( _,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_,_,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_ ),\n",
" ( _,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,1,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_ ),\n",
" ( _,1,1,_,_,_,_,_,_,_,_,1,_,_,_,_,_,1,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
" ( _,1,1,_,_,_,_,_,_,_,_,1,_,_,_,1,_,1,1,_,_,_,_,1,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
" ( _,_,_,_,_,_,_,_,_,_,_,1,_,_,_,_,_,1,_,_,_,_,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
" ( _,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
" ( _,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
" ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
" ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n",
"))\n",
"initial_board = tf.pad(initial_board, ((0, 10), (0, 5)))\n",
"\n",
"_ = gol(initial_board)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "7NgrSPCZxs3h"
},
"source": [
"#### Generated code"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "hIGYeX0Cxs3i"
},
"outputs": [],
"source": [
"print(tf.autograph.to_code(gol.python_function))"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"last_runtime": {
"build_target": "",
"kind": "local"
},
"name": "Simple algorithms using AutoGraph",
"provenance": [
{
"file_id": "19q8KdVF8Cb_fDd13i-WDOG_6n_QGNW5-",
"timestamp": 1528465909719
}
],
"version": "0.3.2"
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
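The notebook above stages the Game of Life update with explicit per-cell `tf.range` loops so that AutoGraph can convert them into graph ops. For comparison only (not the notebook's method), here is a minimal vectorized sketch of the same update rule using a 3x3 neighbor-count convolution; `gol_step_vectorized` is a hypothetical helper name:

```python
import tensorflow as tf

def gol_step_vectorized(board):
  """One Game of Life step for a 2-D int32 board of 0s and 1s."""
  b = tf.cast(board, tf.float32)[tf.newaxis, :, :, tf.newaxis]
  # 3x3 kernel that sums the eight neighbors; the center weight is 0.
  kernel = tf.constant([[1., 1., 1.],
                        [1., 0., 1.],
                        [1., 1., 1.]])[:, :, tf.newaxis, tf.newaxis]
  neighbors = tf.cast(
      tf.nn.conv2d(b, kernel, strides=1, padding='SAME')[0, :, :, 0],
      tf.int32)
  # A live cell survives with 2 neighbors; any cell is live with 3.
  alive = tf.logical_or(
      tf.logical_and(tf.equal(neighbors, 2), tf.equal(board, 1)),
      tf.equal(neighbors, 3))
  return tf.cast(alive, tf.int32)
```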

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

View File

@@ -1,53 +0,0 @@
# Description: Batching scheduling library.
load(
"//tensorflow:tensorflow.bzl",
"py_test",
)
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
py_library(
name = "batch_py",
srcs = glob(["python/ops/*.py"]) + ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:batch_ops",
"//tensorflow/python:batch_ops_gen",
],
)
cc_library(
name = "batch_ops_kernels",
deps = [
"//tensorflow/core:batch_ops_op_lib",
"//tensorflow/core/kernels:batch_kernels",
],
alwayslink = 1,
)
py_test(
name = "batch_ops_test",
size = "small",
srcs = ["python/ops/batch_ops_test.py"],
python_version = "PY2",
shard_count = 5,
srcs_version = "PY2AND3",
tags = [
"manual",
"no_pip",
"nomac",
],
deps = [
":batch_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework",
"//tensorflow/python:gradients",
"//tensorflow/python:script_ops",
],
)

View File

@@ -1,27 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops and modules related to batch.
@@batch_function_v1
@@batch_function
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.batching.python.ops.batch_ops import batch_function
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)

View File

@@ -1,120 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for automatic batching and unbatching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_batch_ops
# pylint: disable=unused-import
from tensorflow.python.ops.batch_ops import batch
from tensorflow.python.ops.batch_ops import batch_function
from tensorflow.python.ops.batch_ops import unbatch
# pylint: enable=unused-import
@ops.RegisterGradient("Batch")
def _BatchGrad(op, *out_grads): # pylint: disable=invalid-name
"""Gradient for batch op."""
gradients = []
for i in range(len(op.inputs)):
gradients.append(
gen_batch_ops.unbatch(
out_grads[i],
op.outputs[-2],
op.outputs[-1],
timeout_micros=op.get_attr("grad_timeout_micros"),
shared_name="batch_gradient_{}_{}".format(op.name, i)))
return gradients
@ops.RegisterGradient("Unbatch")
def _UnbatchGrad(op, grad): # pylint: disable=invalid-name
return [
gen_batch_ops.unbatch_grad(
op.inputs[0],
op.inputs[1],
grad,
op.inputs[2],
shared_name="unbatch_gradient_{}".format(op.name)), None, None
]
def batch_function_v1(num_batch_threads,
max_batch_size,
batch_timeout_micros,
allowed_batch_sizes=None,
grad_timeout_micros=60 * 1000 * 1000,
unbatch_timeout_micros=60 * 1000 * 1000,
max_enqueued_batches=10):
"""Batches the computation done by the decorated function.
This is the older version of batch_function(). Please use batch_function()
instead of this.
Args:
num_batch_threads: Number of scheduling threads for processing batches
of work. Determines the number of batches processed in parallel.
max_batch_size: Batch sizes will never be bigger than this.
batch_timeout_micros: Maximum number of microseconds to wait before
outputting an incomplete batch.
allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,
does nothing. Otherwise, supplies a list of batch sizes, causing the op
to pad batches up to one of those sizes. The entries must increase
monotonically, and the final entry must equal max_batch_size.
grad_timeout_micros: The timeout to use for the gradient. See the
documentation of the unbatch op for more details. Defaults to 60s.
unbatch_timeout_micros: The timeout to use for unbatching. See the
documentation of the unbatch op for more details. Defaults to 60s.
max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10.
Returns:
The decorated function will return the unbatched computation output Tensors.
"""
def decorator(f): # pylint: disable=missing-docstring
def decorated(*args):
with ops.name_scope("batch") as name:
for a in args:
if not isinstance(a, ops.Tensor):
raise ValueError("All arguments to functions decorated with "
"`batch_function` are supposed to be Tensors; "
"found %s" % repr(a))
batched_tensors, batch_index, id_t = gen_batch_ops.batch(
args,
num_batch_threads=num_batch_threads,
max_batch_size=max_batch_size,
batch_timeout_micros=batch_timeout_micros,
max_enqueued_batches=max_enqueued_batches,
allowed_batch_sizes=allowed_batch_sizes,
grad_timeout_micros=grad_timeout_micros,
shared_name=name)
outputs = f(*batched_tensors)
if isinstance(outputs, ops.Tensor):
outputs_list = [outputs]
else:
outputs_list = outputs
with ops.name_scope("unbatch") as unbatch_name:
unbatched = [
gen_batch_ops.unbatch(t, batch_index, id_t,
timeout_micros=unbatch_timeout_micros,
shared_name=unbatch_name + "/" + t.name)
for t in outputs_list]
if isinstance(outputs, ops.Tensor):
return unbatched[0]
return unbatched
return decorated
return decorator
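A minimal usage sketch of the decorator, mirroring `testBasicUnbatchV1Decorated` in the test file below (TF1-style placeholders and sessions assumed; the name `plus_one` is illustrative, and `batch_function_v1` is taken to be in scope as defined above):

```python
import tensorflow.compat.v1 as tf

@batch_function_v1(num_batch_threads=1, max_batch_size=10,
                   batch_timeout_micros=100000)
def plus_one(in_t):
  # Receives a batch assembled from concurrently running callers.
  return in_t + 1

inp = tf.placeholder(dtype=tf.int32, shape=[1])
result = plus_one(inp)
# Each sess.run(result, feed_dict={inp: [x]}) contributes one element to a
# shared batch; runs arriving within batch_timeout_micros are batched
# together, up to max_batch_size elements.
```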

View File

@@ -1,87 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the currently experimental in-graph batch ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import time
from tensorflow.contrib.batching.python.ops import batch_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.platform import test
def delayed_plus1(x):
"""Sleeps for 100ms then returns x+1."""
time.sleep(0.1)
return x + 1
class BatchOpsTest(test.TestCase):
"""Tests for batch_ops.{un,}batch."""
def testBasicUnbatchV1Decorated(self):
"""Tests that the batch_function_v1 decorator works."""
with self.cached_session() as sess:
@batch_ops.batch_function_v1(1, 10, 100000)
def computation(in_t):
return in_t + 1
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
result = computation(inp)
thread_results = []
def worker():
thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([result], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
def testUnbatchGrad(self):
"""Tests that batch and unbatch are differentiable."""
with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
batch_timeout_micros=36000000, grad_timeout_micros=1000000,
batching_queue="")
computation = batched[0] * batched[0]
result = batch_ops.unbatch(computation, index, id_t,
timeout_micros=1000000, shared_name="unbatch")
grad = gradients_impl.gradients(result, inp)
thread_results = []
def worker():
thread_results.extend(sess.run([grad], feed_dict={inp: [1]}))
worker_thread = threading.Thread(target=worker)
worker_thread.start()
main_results = sess.run([grad], feed_dict={inp: [2]})
worker_thread.join()
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [4])
if __name__ == "__main__":
test.main()

View File

@@ -1,60 +0,0 @@
# Description:
# Contains ops for working with statistical distributions,
# particularly useful for Bayesian inference.
# APIs here are meant to evolve over time.
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
package(
default_visibility = [
"//learning/brain/contrib/bayesflow:__subpackages__",
"//tensorflow:__subpackages__",
],
licenses = ["notice"], # Apache 2.0
)
exports_files(["LICENSE"])
py_library(
name = "bayesflow_py",
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:functional_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform",
"//tensorflow/python:random_ops",
"//tensorflow/python:state_ops",
"//tensorflow/python:util",
"//third_party/py/numpy",
],
)
cuda_py_test(
name = "monte_carlo_test",
size = "small",
srcs = ["python/kernel_tests/monte_carlo_test.py"],
additional_deps = [
":bayesflow_py",
"//third_party/py/numpy",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python/ops/distributions",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_seed",
],
)

View File

@@ -1,17 +0,0 @@
# Notice
`tf.contrib.bayesflow` has moved!
See new code at [github.com/tensorflow/probability](
https://github.com/tensorflow/probability).
Switch imports with:
```python
# old
import tensorflow as tf
tfp = tf.contrib.bayesflow
# new
import tensorflow_probability as tfp
```

View File

@@ -1,36 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for representing Bayesian computation.
Use [tfp](/probability/api_docs/python/tfp) instead.
## This package provides classes for Bayesian computation with TensorFlow.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,line-too-long
from tensorflow.contrib.bayesflow.python.ops import monte_carlo
# pylint: enable=unused-import,line-too-long
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'monte_carlo',
]
remove_undocumented(__name__, _allowed_symbols)

View File

@@ -1,19 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ops module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

View File

@@ -1,260 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Monte Carlo Ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib import layers as layers_lib
from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo_lib
from tensorflow.contrib.bayesflow.python.ops.monte_carlo_impl import _get_samples
from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.platform import test
layers = layers_lib
mc = monte_carlo_lib
class ExpectationImportanceSampleTest(test.TestCase):
def test_normal_integral_mean_and_var_correctly_estimated(self):
n = int(1e6)
with self.cached_session():
mu_p = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64)
mu_q = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
sigma_p = constant_op.constant([0.5, 0.5], dtype=dtypes.float64)
sigma_q = constant_op.constant([1.0, 1.0], dtype=dtypes.float64)
p = normal_lib.Normal(loc=mu_p, scale=sigma_p)
q = normal_lib.Normal(loc=mu_q, scale=sigma_q)
# Compute E_p[X].
e_x = mc.expectation_importance_sampler(
f=lambda x: x, log_p=p.log_prob, sampling_dist_q=q, n=n, seed=42)
# Compute E_p[X^2].
e_x2 = mc.expectation_importance_sampler(
f=math_ops.square, log_p=p.log_prob, sampling_dist_q=q, n=n, seed=42)
stddev = math_ops.sqrt(e_x2 - math_ops.square(e_x))
# Relative tolerance (rtol) chosen 2 times as large as minimum needed to
# pass.
# Convergence of mean is +- 0.003 if n = 100M
# Convergence of stddev is +- 0.00001 if n = 100M
self.assertEqual(p.batch_shape, e_x.get_shape())
self.assertAllClose(p.mean().eval(), e_x.eval(), rtol=0.01)
self.assertAllClose(p.stddev().eval(), stddev.eval(), rtol=0.02)
def test_multivariate_normal_prob_positive_product_of_components(self):
# Test that importance sampling can correctly estimate the probability that
# the product of components in a MultivariateNormal are > 0.
n = 1000
with self.cached_session():
p = mvn_diag_lib.MultivariateNormalDiag(
loc=[0.], scale_diag=[1.0, 1.0])
q = mvn_diag_lib.MultivariateNormalDiag(
loc=[0.5], scale_diag=[3., 3.])
# Compute E_p[X_1 * X_2 > 0], with X_i the ith component of X ~ p(x).
# Should equal 1/2 because p is a spherical Gaussian centered at (0, 0).
def indicator(x):
x1_times_x2 = math_ops.reduce_prod(x, axis=[-1])
return 0.5 * (math_ops.sign(x1_times_x2) + 1.0)
prob = mc.expectation_importance_sampler(
f=indicator, log_p=p.log_prob, sampling_dist_q=q, n=n, seed=42)
# Relative tolerance (rtol) chosen 2 times as large as minimum needed to
# pass.
# Convergence is +- 0.004 if n = 100k.
self.assertEqual(p.batch_shape, prob.get_shape())
self.assertAllClose(0.5, prob.eval(), rtol=0.05)
class ExpectationImportanceSampleLogspaceTest(test.TestCase):
def test_normal_distribution_second_moment_estimated_correctly(self):
# Test the importance sampled estimate against an analytical result.
n = int(1e6)
with self.cached_session():
mu_p = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
mu_q = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64)
sigma_p = constant_op.constant([1.0, 2 / 3.], dtype=dtypes.float64)
sigma_q = constant_op.constant([1.0, 1.0], dtype=dtypes.float64)
p = normal_lib.Normal(loc=mu_p, scale=sigma_p)
q = normal_lib.Normal(loc=mu_q, scale=sigma_q)
# Compute E_p[X^2].
# Should equal [1, (2/3)^2]
log_e_x2 = mc.expectation_importance_sampler_logspace(
log_f=lambda x: math_ops.log(math_ops.square(x)),
log_p=p.log_prob,
sampling_dist_q=q,
n=n,
seed=42)
e_x2 = math_ops.exp(log_e_x2)
# Relative tolerance (rtol) chosen 2 times as large as minimum needed to
# pass.
self.assertEqual(p.batch_shape, e_x2.get_shape())
self.assertAllClose([1., (2 / 3.)**2], e_x2.eval(), rtol=0.02)
class GetSamplesTest(test.TestCase):
"""Test the private method 'get_samples'."""
def test_raises_if_both_z_and_n_are_none(self):
with self.cached_session():
dist = normal_lib.Normal(loc=0., scale=1.)
z = None
n = None
seed = None
with self.assertRaisesRegexp(ValueError, 'exactly one'):
_get_samples(dist, z, n, seed)
def test_raises_if_both_z_and_n_are_not_none(self):
with self.cached_session():
dist = normal_lib.Normal(loc=0., scale=1.)
z = dist.sample(seed=42)
n = 1
seed = None
with self.assertRaisesRegexp(ValueError, 'exactly one'):
_get_samples(dist, z, n, seed)
def test_returns_n_samples_if_n_provided(self):
with self.cached_session():
dist = normal_lib.Normal(loc=0., scale=1.)
z = None
n = 10
seed = None
z = _get_samples(dist, z, n, seed)
self.assertEqual((10,), z.get_shape())
def test_returns_z_if_z_provided(self):
with self.cached_session():
dist = normal_lib.Normal(loc=0., scale=1.)
z = dist.sample(10, seed=42)
n = None
seed = None
z = _get_samples(dist, z, n, seed)
self.assertEqual((10,), z.get_shape())
class ExpectationTest(test.TestCase):
def test_works_correctly(self):
with self.cached_session() as sess:
x = constant_op.constant([-1e6, -100, -10, -1, 1, 10, 100, 1e6])
p = normal_lib.Normal(loc=x, scale=1.)
# We use the prefix "efx" to mean "E_p[f(X)]".
f = lambda u: u
efx_true = x
samples = p.sample(int(1e5), seed=1)
efx_reparam = mc.expectation(f, samples, p.log_prob)
efx_score = mc.expectation(f, samples, p.log_prob,
use_reparametrization=False)
[
efx_true_,
efx_reparam_,
efx_score_,
efx_true_grad_,
efx_reparam_grad_,
efx_score_grad_,
] = sess.run([
efx_true,
efx_reparam,
efx_score,
gradients_impl.gradients(efx_true, x)[0],
gradients_impl.gradients(efx_reparam, x)[0],
gradients_impl.gradients(efx_score, x)[0],
])
self.assertAllEqual(np.ones_like(efx_true_grad_), efx_true_grad_)
self.assertAllClose(efx_true_, efx_reparam_, rtol=0.005, atol=0.)
self.assertAllClose(efx_true_, efx_score_, rtol=0.005, atol=0.)
self.assertAllEqual(np.ones_like(efx_true_grad_, dtype=np.bool),
np.isfinite(efx_reparam_grad_))
self.assertAllEqual(np.ones_like(efx_true_grad_, dtype=np.bool),
np.isfinite(efx_score_grad_))
self.assertAllClose(efx_true_grad_, efx_reparam_grad_,
rtol=0.03, atol=0.)
# Variance is too high to be meaningful, so we'll only check those which
# converge.
self.assertAllClose(efx_true_grad_[2:-2],
efx_score_grad_[2:-2],
rtol=0.05, atol=0.)
def test_docstring_example_normal(self):
with self.cached_session() as sess:
num_draws = int(1e5)
mu_p = constant_op.constant(0.)
mu_q = constant_op.constant(1.)
p = normal_lib.Normal(loc=mu_p, scale=1.)
q = normal_lib.Normal(loc=mu_q, scale=2.)
exact_kl_normal_normal = kullback_leibler.kl_divergence(p, q)
approx_kl_normal_normal = monte_carlo_lib.expectation(
f=lambda x: p.log_prob(x) - q.log_prob(x),
samples=p.sample(num_draws, seed=42),
log_prob=p.log_prob,
use_reparametrization=(p.reparameterization_type
== distribution_lib.FULLY_REPARAMETERIZED))
[exact_kl_normal_normal_, approx_kl_normal_normal_] = sess.run([
exact_kl_normal_normal, approx_kl_normal_normal])
self.assertEqual(
True,
p.reparameterization_type == distribution_lib.FULLY_REPARAMETERIZED)
self.assertAllClose(exact_kl_normal_normal_, approx_kl_normal_normal_,
rtol=0.01, atol=0.)
# Compare gradients. (Not present in `docstring`.)
gradp = lambda fp: gradients_impl.gradients(fp, mu_p)[0]
gradq = lambda fq: gradients_impl.gradients(fq, mu_q)[0]
[
gradp_exact_kl_normal_normal_,
gradq_exact_kl_normal_normal_,
gradp_approx_kl_normal_normal_,
gradq_approx_kl_normal_normal_,
] = sess.run([
gradp(exact_kl_normal_normal),
gradq(exact_kl_normal_normal),
gradp(approx_kl_normal_normal),
gradq(approx_kl_normal_normal),
])
self.assertAllClose(gradp_exact_kl_normal_normal_,
gradp_approx_kl_normal_normal_,
rtol=0.01, atol=0.)
self.assertAllClose(gradq_exact_kl_normal_normal_,
gradq_approx_kl_normal_normal_,
rtol=0.01, atol=0.)
if __name__ == '__main__':
test.main()

View File

@@ -1,36 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Monte Carlo integration and helpers.
Use [tfp.monte_carlo](/probability/api_docs/python/tfp/monte_carlo) instead.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.contrib.bayesflow.python.ops.monte_carlo_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'expectation',
'expectation_importance_sampler',
'expectation_importance_sampler_logspace',
]
remove_undocumented(__name__, _allowed_symbols)

View File

@ -1,374 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Monte Carlo integration and helpers.
@@expectation
@@expectation_importance_sampler
@@expectation_importance_sampler_logspace
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util import deprecation
__all__ = [
'expectation',
'expectation_importance_sampler',
'expectation_importance_sampler_logspace',
]
def expectation_importance_sampler(f,
log_p,
sampling_dist_q,
z=None,
n=None,
seed=None,
name='expectation_importance_sampler'):
r"""Monte Carlo estimate of \\(E_p[f(Z)] = E_q[f(Z) p(Z) / q(Z)]\\).
With \\(p(z) := e^{log_p(z)}\\), this `Op` returns
\\(n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ], z_i ~ q,\\)
\\(\approx E_q[ f(Z) p(Z) / q(Z) ]\\)
\\(= E_p[f(Z)]\\)
This integral is done in log-space with max-subtraction to better handle the
often extreme values that `f(z) p(z) / q(z)` can take on.
If `f >= 0`, it is up to 2x more efficient to exponentiate the result of
`expectation_importance_sampler_logspace` applied to `Log[f]`.
The user supplies either a `Tensor` of samples `z`, or the number of samples to draw, `n`.
Args:
f: Callable mapping samples from `sampling_dist_q` to `Tensors` with shape
broadcastable to `q.batch_shape`.
For example, `f` works "just like" `q.log_prob`.
log_p: Callable mapping samples from `sampling_dist_q` to `Tensors` with
shape broadcastable to `q.batch_shape`.
For example, `log_p` works "just like" `sampling_dist_q.log_prob`.
sampling_dist_q: The sampling distribution.
`tfp.distributions.Distribution`.
`float64` `dtype` recommended.
`log_p` and `q` should be supported on the same set.
z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
seed: Python integer to seed the random number generator.
name: A name to give this `Op`.
Returns:
The importance sampling estimate. `Tensor` with `shape` equal
to batch shape of `q`, and `dtype` = `q.dtype`.
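Example Use (an illustrative sketch added here for clarity; the
distributions, `f`, and `n` below are assumptions, not part of the original
docstring):

```python
import tensorflow_probability as tfp
tfd = tfp.distributions

p = tfd.Normal(loc=0., scale=1.)   # Target distribution.
q = tfd.Normal(loc=0., scale=2.)   # Wider sampling distribution.
approx_second_moment = expectation_importance_sampler(
    f=lambda z: z**2,              # E_p[Z^2] = 1 for a standard Normal.
    log_p=p.log_prob,
    sampling_dist_q=q,
    n=int(1e5),
    seed=42)
# ==> approximately 1.
```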
"""
q = sampling_dist_q
with ops.name_scope(name, values=[z, n]):
z = _get_samples(q, z, n, seed)
log_p_z = log_p(z)
q_log_prob_z = q.log_prob(z)
def _importance_sampler_positive_f(log_f_z):
# Same as expectation_importance_sampler_logspace, but using Tensors
# rather than samples and functions. Allows us to sample once.
log_values = log_f_z + log_p_z - q_log_prob_z
return _logspace_mean(log_values)
# With \\(f_{plus}(z) = max(0, f(z)), f_{minus}(z) = max(0, -f(z))\\),
# \\(E_p[f(Z)] = E_p[f_{plus}(Z)] - E_p[f_{minus}(Z)]\\)
# \\( = E_p[f_{plus}(Z) + 1] - E_p[f_{minus}(Z) + 1]\\)
# Without incurring bias, 1 is added to each to prevent zeros in logspace.
# The logarithm is approximately linear around 1 + epsilon, so this is good
# for small values of 'z' as well.
f_z = f(z)
log_f_plus_z = math_ops.log(nn.relu(f_z) + 1.)
log_f_minus_z = math_ops.log(nn.relu(-1. * f_z) + 1.)
log_f_plus_integral = _importance_sampler_positive_f(log_f_plus_z)
log_f_minus_integral = _importance_sampler_positive_f(log_f_minus_z)
return math_ops.exp(log_f_plus_integral) - math_ops.exp(log_f_minus_integral)
def expectation_importance_sampler_logspace(
log_f,
log_p,
sampling_dist_q,
z=None,
n=None,
seed=None,
name='expectation_importance_sampler_logspace'):
r"""Importance sampling with a positive function, in log-space.
With \\(p(z) := e^{log_p(z)}\\) and \\(f(z) = e^{log_f(z)}\\),
this `Op` returns
\\(Log[ n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ] ], z_i ~ q,\\)
\\(\approx Log[ E_q[ f(Z) p(Z) / q(Z) ] ]\\)
\\(= Log[E_p[f(Z)]]\\)
This integral is done in log-space with max-subtraction to better handle the
often extreme values that `f(z) p(z) / q(z)` can take on.
In contrast to `expectation_importance_sampler`, this `Op` returns values in
log-space.
The user supplies either a `Tensor` of samples `z`, or the number of samples to draw, `n`.
Args:
log_f: Callable mapping samples from `sampling_dist_q` to `Tensors` with
shape broadcastable to `q.batch_shape`.
For example, `log_f` works "just like" `sampling_dist_q.log_prob`.
log_p: Callable mapping samples from `sampling_dist_q` to `Tensors` with
shape broadcastable to `q.batch_shape`.
For example, `log_p` works "just like" `q.log_prob`.
sampling_dist_q: The sampling distribution.
`tfp.distributions.Distribution`.
`float64` `dtype` recommended.
`log_p` and `q` should be supported on the same set.
z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
n: Integer `Tensor`. Number of samples to generate if `z` is not provided.
seed: Python integer to seed the random number generator.
name: A name to give this `Op`.
Returns:
Logarithm of the importance sampling estimate. `Tensor` with `shape` equal
to batch shape of `q`, and `dtype` = `q.dtype`.
"""
q = sampling_dist_q
with ops.name_scope(name, values=[z, n]):
z = _get_samples(q, z, n, seed)
log_values = log_f(z) + log_p(z) - q.log_prob(z)
return _logspace_mean(log_values)
def _logspace_mean(log_values):
"""Evaluate `Log[E[values]]` in a stable manner.
Args:
log_values: `Tensor` holding `Log[values]`.
Returns:
`Tensor` of same `dtype` as `log_values`, reduced across dim 0.
`Log[Mean[values]]`.
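Example (an illustrative sketch, assuming NumPy-like values; not part of
the original docstring):

```python
import numpy as np
log_values = np.array([-1000., -1001.])  # exp() underflows to 0.
center = log_values.max()
# Stable evaluation of log(mean(exp(log_values))) via max-centering:
stable = np.log(np.mean(np.exp(log_values - center))) + center
# ==> approximately -1000.38; the naive computation yields -inf.
```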
"""
# center = Max[Log[values]], with stop-gradient
# The center hopefully keeps the exponentiated terms small. It cancels out
# of the final result, so stopping its gradient does not change the result;
# we stop the gradient only to eliminate unnecessary computation.
center = array_ops.stop_gradient(_sample_max(log_values))
# centered_values = exp{log_values - center}
centered_values = math_ops.exp(log_values - center)
# log_mean_of_values = Log[ E[centered_values] ] + center
#                    = Log[ E[exp{log_values - center}] ] + center
#                    = Log[E[values]] - center + center
#                    = Log[E[values]]
log_mean_of_values = math_ops.log(_sample_mean(centered_values)) + center
return log_mean_of_values
@deprecation.deprecated(
'2018-10-01',
'The tf.contrib.bayesflow library has moved to '
'TensorFlow Probability (https://github.com/tensorflow/probability). '
'Use `tfp.monte_carlo.expectation` instead.',
warn_once=True)
def expectation(f, samples, log_prob=None, use_reparametrization=True,
axis=0, keep_dims=False, name=None):
r"""Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\).
This function computes the Monte-Carlo approximation of an expectation, i.e.,
\\(E_p[f(X)] \approx m^{-1} sum_{j=1}^m f(x_j),\ x_j \overset{iid}{\sim} p(X)\\)
where:
- `x_j = samples[j, ...]`,
- `log(p(samples)) = log_prob(samples)` and
- `m = prod(shape(samples)[axis])`.
Tricks: Reparameterization and Score-Gradient
When p is "reparameterized", i.e., a diffeomorphic transformation of a
parameterless distribution (e.g.,
`Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and
expectation, i.e.,
grad[ Avg{ \\(s_i\\) : i=1...n } ] = Avg{ grad[\\(s_i\\)] : i=1...n }, where
\\(s_i = f(x_i)\\) and \\(x_i \sim p\\).
However, if p is not reparameterized, TensorFlow's gradient will be incorrect
since the chain-rule stops at samples of non-reparameterized distributions.
(The non-differentiated result, `approx_expectation`, is the same regardless
of `use_reparametrization`.) In this circumstance using the Score-Gradient
trick results in an unbiased gradient, i.e.,
```none
grad[ E_p[f(X)] ]
= grad[ int dx p(x) f(x) ]
= int dx grad[ p(x) f(x) ]
= int dx [ p'(x) f(x) + p(x) f'(x) ]
= int dx p(x) [p'(x) / p(x) f(x) + f'(x) ]
= int dx p(x) grad[ f(x) p(x) / stop_grad[p(x)] ]
= E_p[ grad[ f(x) p(x) / stop_grad[p(x)] ] ]
```
When `p` is reparameterized, it is usually preferable to set
`use_reparametrization = True`.
Warning: users are responsible for verifying `p` is a "reparameterized"
distribution.
Example Use:
```python
import tensorflow_probability as tfp
tfd = tfp.distributions
# Monte-Carlo approximation of a reparameterized distribution, e.g., Normal.
num_draws = int(1e5)
p = tfd.Normal(loc=0., scale=1.)
q = tfd.Normal(loc=1., scale=2.)
exact_kl_normal_normal = tfd.kl_divergence(p, q)
# ==> 0.44314718
approx_kl_normal_normal = tfp.monte_carlo.expectation(
f=lambda x: p.log_prob(x) - q.log_prob(x),
samples=p.sample(num_draws, seed=42),
log_prob=p.log_prob,
use_reparametrization=(p.reparameterization_type
== tfd.FULLY_REPARAMETERIZED))
# ==> 0.44632751
# Relative Error: <1%
# Monte-Carlo approximation of non-reparameterized distribution, e.g., Gamma.
num_draws = int(1e5)
p = tfd.Gamma(concentration=1., rate=1.)
q = tfd.Gamma(concentration=2., rate=3.)
exact_kl_gamma_gamma = tfd.kl_divergence(p, q)
# ==> 0.37999129
approx_kl_gamma_gamma = tfp.monte_carlo.expectation(
f=lambda x: p.log_prob(x) - q.log_prob(x),
samples=p.sample(num_draws, seed=42),
log_prob=p.log_prob,
use_reparametrization=(p.reparameterization_type
== tfd.FULLY_REPARAMETERIZED))
# ==> 0.37696719
# Relative Error: <1%
# For comparing the gradients, see `monte_carlo_test.py`.
```
Note: The above example is for illustration only. To compute approximate
KL-divergence, the following is preferred:
```python
approx_kl_p_q = tfp.vi.monte_carlo_csiszar_f_divergence(
f=tfp.vi.kl_reverse,
p_log_prob=q.log_prob,
q=p,
num_draws=num_draws)
```
Args:
f: Python callable which can return `f(samples)`.
samples: `Tensor` of samples used to form the Monte-Carlo approximation of
\\(E_p[f(X)]\\). A batch of samples should be indexed by `axis`
dimensions.
log_prob: Python callable which can return `log_prob(samples)`. Must
correspond to the natural-logarithm of the pdf/pmf of each sample. Only
required/used if `use_reparametrization=False`.
Default value: `None`.
use_reparametrization: Python `bool` indicating that the approximation
should use the fact that the gradient of samples is unbiased. Whether
`True` or `False`, this arg only affects the gradient of the resulting
`approx_expectation`.
Default value: `True`.
axis: The dimensions to average. If `None`, averages all
dimensions.
Default value: `0` (the left-most dimension).
keep_dims: If True, retains averaged dimensions using size `1`.
Default value: `False`.
name: A `name_scope` for operations created by this function.
Default value: `None` (which implies "expectation").
Returns:
approx_expectation: `Tensor` corresponding to the Monte-Carlo approximation
of \\(E_p[f(X)]\\).
Raises:
ValueError: if `f` is not a Python `callable`.
ValueError: if `use_reparametrization=False` and `log_prob` is not a Python
`callable`.
"""
with ops.name_scope(name, 'expectation', [samples]):
if not callable(f):
raise ValueError('`f` must be a callable function.')
if use_reparametrization:
return math_ops.reduce_mean(f(samples), axis=axis, keepdims=keep_dims)
else:
if not callable(log_prob):
raise ValueError('`log_prob` must be a callable function.')
stop = array_ops.stop_gradient # For readability.
x = stop(samples)
logpx = log_prob(x)
fx = f(x) # Call `f` once in case it has side-effects.
# We now rewrite f(x) so that:
# `grad[f(x)] := grad[f(x)] + f(x) * grad[logpx]`.
# To achieve this, we use a trick that
# `h(x) - stop(h(x)) == zeros_like(h(x))`
# but its gradient is grad[h(x)].
# Note that IEEE754 specifies that `x - x == 0.` and `x + 0. == x`, hence
# this trick loses no precision. For more discussion regarding the
# relevant portions of the IEEE754 standard, see the StackOverflow
# question,
# "Is there a floating point value of x, for which x-x == 0 is false?"
# http://stackoverflow.com/q/2686644
fx += stop(fx) * (logpx - stop(logpx)) # Add zeros_like(logpx).
return math_ops.reduce_mean(fx, axis=axis, keepdims=keep_dims)
def _sample_mean(values):
"""Mean over sample indices. In this module this is always [0]."""
return math_ops.reduce_mean(values, axis=[0])
def _sample_max(values):
"""Max over sample indices. In this module this is always [0]."""
return math_ops.reduce_max(values, axis=[0])
def _get_samples(dist, z, n, seed):
"""Check args and return samples."""
with ops.name_scope('get_samples', values=[z, n]):
if (n is None) == (z is None):
raise ValueError(
'Must specify exactly one of arguments "n" and "z". Found: '
'n = %s, z = %s' % (n, z))
if n is not None:
return dist.sample(n, seed=seed)
else:
return ops.convert_to_tensor(z, name='z')

View File

@ -1,4 +0,0 @@
# Benchmarking Scripts
This directory tree contains a set of scripts that are useful when benchmarking
TensorFlow.

View File

@ -1,10 +0,0 @@
# Distributed Tensorflow on Google Compute Engine
The scripts in this directory automate the work to run distributed TensorFlow on
a cluster of GCE instances.
## Pre-work
Before getting started, while GPUs on GCE are in Alpha, you will need to get
your project whitelisted in order to get access. These scripts will not work
until then.

View File

@ -1,212 +0,0 @@
# Cloud Bigtable client for TensorFlow
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
"tf_copts",
"tf_custom_op_library",
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
"tf_kernel_library",
"tf_py_test",
)
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
)
tf_custom_op_py_library(
name = "bigtable",
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
dso = [
":python/ops/_bigtable.so",
],
kernels = [
":bigtable_kernels",
":bigtable_ops_op_lib",
],
srcs_version = "PY2AND3",
deps = [
":bigtable_ops",
"//tensorflow/contrib/data/python/ops:interleave_ops",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
"//tensorflow/python:util",
"//tensorflow/python/data",
],
)
KERNEL_FILES = [
"kernels/bigtable_kernels.cc",
"kernels/bigtable_lookup_dataset_op.cc",
"kernels/bigtable_prefix_key_dataset_op.cc",
"kernels/bigtable_range_key_dataset_op.cc",
"kernels/bigtable_sample_keys_dataset_op.cc",
"kernels/bigtable_sample_key_pairs_dataset_op.cc",
"kernels/bigtable_scan_dataset_op.cc",
]
tf_custom_op_library(
name = "python/ops/_bigtable.so",
srcs = KERNEL_FILES + [
"ops/bigtable_ops.cc",
],
deps = [
":bigtable_lib_cc",
":bigtable_range_helpers",
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
],
)
tf_gen_op_wrapper_py(
name = "bigtable_ops",
deps = [":bigtable_ops_op_lib"],
)
tf_gen_op_libs(
op_lib_names = [
"bigtable_ops",
"bigtable_test_ops",
],
)
tf_kernel_library(
name = "bigtable_kernels",
srcs = KERNEL_FILES,
deps = [
":bigtable_lib_cc",
":bigtable_range_helpers",
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
],
)
# A library for use in the bigtable kernels.
cc_library(
name = "bigtable_lib_cc",
srcs = ["kernels/bigtable_lib.cc"],
hdrs = ["kernels/bigtable_lib.h"],
deps = [
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
],
)
cc_library(
name = "bigtable_range_helpers",
srcs = ["kernels/bigtable_range_helpers.cc"],
hdrs = ["kernels/bigtable_range_helpers.h"],
deps = [
"//tensorflow/core:framework_headers_lib",
],
)
cc_library(
name = "bigtable_test_client",
srcs = ["kernels/test_kernels/bigtable_test_client.cc"],
hdrs = ["kernels/test_kernels/bigtable_test_client.h"],
deps = [
"//tensorflow/core:framework_headers_lib",
"@com_github_googleapis_googleapis//:bigtable_protos",
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
"@com_googlesource_code_re2//:re2",
],
)
tf_cc_test(
name = "bigtable_test_client_test",
srcs = ["kernels/test_kernels/bigtable_test_client_test.cc"],
tags = ["manual"],
deps = [
":bigtable_test_client",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
],
)
tf_cc_test(
name = "bigtable_range_helpers_test",
size = "small",
srcs = ["kernels/bigtable_range_helpers_test.cc"],
deps = [
":bigtable_range_helpers",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_gen_op_wrapper_py(
name = "bigtable_test_ops",
deps = [":bigtable_test_ops_op_lib"],
)
tf_custom_op_library(
name = "python/kernel_tests/_bigtable_test.so",
srcs = [
"kernels/test_kernels/bigtable_test_client_op.cc",
"ops/bigtable_test_ops.cc",
],
deps = [
":bigtable_lib_cc",
":bigtable_test_client",
"@com_googlesource_code_re2//:re2",
],
)
# Don't use tf_kernel_library because it prevents access to strings/stringprintf.h
cc_library(
name = "bigtable_test_kernels",
srcs = [
"kernels/test_kernels/bigtable_test_client_op.cc",
],
copts = tf_copts(),
linkstatic = 1,
deps = [
":bigtable_lib_cc",
":bigtable_test_client",
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
"@com_googlesource_code_re2//:re2",
],
alwayslink = 1,
)
tf_custom_op_py_library(
name = "bigtable_test_py",
dso = [
":python/kernel_tests/_bigtable_test.so",
],
kernels = [
":bigtable_test_kernels",
":bigtable_test_ops_op_lib",
],
srcs_version = "PY2AND3",
deps = [
":bigtable_test_ops",
],
)
tf_py_test(
name = "bigtable_ops_test",
size = "small",
srcs = ["python/kernel_tests/bigtable_ops_test.py"],
additional_deps = [
":bigtable",
":bigtable_test_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
"//tensorflow/python:util",
],
tags = ["manual"],
)

View File

@ -1,346 +0,0 @@
# Google Cloud Bigtable
[Cloud Bigtable](https://cloud.google.com/bigtable/) is a high
performance storage system that can store and serve training data. This contrib
package contains an experimental integration with TensorFlow.
> **Status: Highly experimental.** The current implementation is very much in
> flux. Please use at your own risk! :-)
The TensorFlow integration with Cloud Bigtable is optimized for common
TensorFlow usage and workloads. It is currently optimized for reading from Cloud
Bigtable at high speed, in particular to feed modern accelerators. For
general-purpose Cloud Bigtable
APIs, see the [official Cloud Bigtable client library documentation][clientdoc].
[clientdoc]: https://cloud.google.com/bigtable/docs/reference/libraries
## Sample Use
There are three main reading styles supported by the `BigtableTable` class:
1. **Reading keys**: Read only the row keys in a table. Keys are returned in
sorted order from the table. Most key reading operations retrieve all keys
in a contiguous range; the `sample_keys` operation, however, skips keys and
operates on the whole table (not a contiguous subset).
2. **Retrieving a row's values**: Given a row key, look up the data associated
with a defined set of columns. This operation takes advantage of Cloud
Bigtable's low-latency and excellent support for random access.
3. **Scanning ranges**: Given a contiguous range of rows retrieve both the row
key and the data associated with a fixed set of columns. This operation
takes advantage of Cloud Bigtable's high throughput scans, and is the most
efficient way to read data.
When using the Cloud Bigtable API, the workflow is:
1. Create a `BigtableClient` object.
2. Use the `BigtableClient` to create `BigtableTable` objects corresponding to
each table in the Cloud Bigtable instance you would like to access.
3. Call methods on the `BigtableTable` object to create `tf.data.Dataset`s to
retrieve data.
The following is an example of how to read all row keys with the prefix
`train-`.
```python
import tensorflow as tf
GCP_PROJECT_ID = '<FILL_ME_IN>'
BIGTABLE_INSTANCE_ID = '<FILL_ME_IN>'
BIGTABLE_TABLE_NAME = '<FILL_ME_IN>'
PREFIX = 'train-'
def main():
tf.enable_eager_execution()
client = tf.contrib.cloud.BigtableClient(GCP_PROJECT_ID, BIGTABLE_INSTANCE_ID)
table = client.table(BIGTABLE_TABLE_NAME)
dataset = table.keys_by_prefix_dataset(PREFIX)
print('Retrieving rows:')
row_index = 0
for row_key in dataset:
print('Row key %d: %s' % (row_index, row_key))
row_index += 1
print('Finished reading data!')
if __name__ == '__main__':
main()
```
### Reading row keys
Read only the row keys in a table. Keys are returned in sorted order from the
table. Most key reading operations retrieve all keys in a contiguous range;
the `sample_keys` operation, however, skips keys and operates on the whole
table (not a contiguous subset).
There are three methods to retrieve row keys; a short usage sketch follows the list:
- `table.keys_by_range_dataset(start, end)`: Retrieve row keys starting with
`start`, and ending with `end`. The range is "half-open", and thus it
includes `start` if `start` is present in the table. It does not include
`end`.
- `table.keys_by_prefix_dataset(prefix)`: Retrieves all row keys that start
with `prefix`. It includes the row key `prefix` if present in the table.
- `table.sample_keys()`: Retrieves a sampling of keys from the underlying
table. This is often useful in conjunction with parallel scans.
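A minimal sketch of these three methods (assuming a `table` created as in the
example above; the key names are placeholders):

```python
keys_in_range = table.keys_by_range_dataset('train-00000', 'train-10000')
keys_with_prefix = table.keys_by_prefix_dataset('train-')
sampled_keys = table.sample_keys()  # Often used to partition parallel scans.
```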
### Reading cell values given a row key
Given a dataset producing row keys, you can use the `table.lookup_columns`
transformation to retrieve values. Example:
```python
key_dataset = tf.data.Dataset.from_tensor_slices([
'row_key_1',
'other_row_key',
'final_row_key',
])
values_dataset = key_dataset.apply(
table.lookup_columns(('my_column_family', 'column_name'),
('other_cf', 'col')))
training_data = values_dataset.map(my_parsing_function) # ...
```
### Scanning ranges
Given a contiguous range of rows retrieve both the row key and the data
associated with a fixed set of columns. Scanning is the most efficient way to
retrieve data from Cloud Bigtable and is thus a very common API for high
performance data pipelines. To construct a scanning `tf.data.Dataset` from a
`BigtableTable` object, call one of the following methods:
- `table.scan_prefix(prefix, ...)`
- `table.scan_range(start, end, ...)`
- `table.parallel_scan_prefix(prefix, ...)`
- `table.parallel_scan_range(start, end, ...)`
Aside from the specification of the contiguous range of rows, they all take the
following arguments:
- `probability`: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
A value less than 1 causes each row to be included probabilistically with
the provided probability.
- `columns`: The columns to read. (See below.)
- `**kwargs`: The columns to read. (See below.)
In addition, the two parallel operations accept an optional argument,
`num_parallel_scans`, which configures the number of parallel Cloud Bigtable scan
operations to run. A reasonable default is automatically chosen for small
Cloud Bigtable clusters. If you have a large cluster, or an extremely demanding
workload, you can tune this value to optimize performance.
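For example (a sketch; the prefix, column names, and scan count below are
illustrative placeholders, not from the original document):

```python
ds = table.parallel_scan_prefix('train-',
                                num_parallel_scans=8,
                                columns=[('cf1', 'col1'), ('cf1', 'col2')])
```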
#### Specifying columns to read when scanning
All of the scan operations allow you to specify the column family and columns
in the same ways.
##### Using `columns`
The first way to specify the data to read is via the `columns` parameter. The
value should be a tuple (or list of tuples) of strings. The first string in the
tuple is the column family, and the second string in the tuple is the column
qualifier.
##### Using `**kwargs`
The second way to specify the data to read is via the `**kwargs` parameter,
which you can use to specify keyword arguments corresponding to the columns that
you want to read. The keyword to use is the column family name, and the argument
value should be either a string, or a tuple of strings, specifying the column
qualifiers (column names).
Although using `**kwargs` has the advantage of requiring less typing, it is not
future-proof in all cases. (If we add a new parameter to the scan functions that
has the same name as your column family, your code will break.)
##### Examples
Below are two equivalent snippets showing how to specify which columns to read:
```python
ds1 = table.scan_range("row_start", "row_end", columns=[("cfa", "c1"),
("cfa", "c2"),
("cfb", "c3")])
ds2 = table.scan_range("row_start", "row_end", cfa=["c1", "c2"], cfb="c3")
```
In this example, we are reading 3 columns from a total of 2 column families.
From the `cfa` column family, we are reading columns `c1`, and `c2`. From the
second column family (`cfb`), we are reading `c3`. Both `ds1` and `ds2` will
output elements of the following types (`tf.string`, `tf.string`, `tf.string`,
`tf.string`). The first `tf.string` is the row key, the second `tf.string` is
the latest data in cell `cfa:c1`, the third corresponds to `cfa:c2`, and the
final one is `cfb:c3`.
#### Determinism when scanning
While the non-parallel scan operations are fully deterministic, the parallel
scan operations are not. If you would like to scan in parallel without losing
determinism, you can build up the `parallel_interleave` yourself. As an example,
say we want to scan all rows between `training_data_00000` and
`training_data_90000`; we can use the following code snippet:
```python
table = # ...
columns = [('cf1', 'col1'), ('cf1', 'col2')]
NUM_PARALLEL_READS = # ...
ds = tf.data.Dataset.range(9).shuffle(10)
def interleave_fn(index):
# Given a starting index, create 2 strings to be the start and end
start_idx = index
end_idx = index + 1
start_idx_str = tf.as_string(start_idx * 10000, width=5, fill='0')
end_idx_str = tf.as_string(end_idx * 10000, width=5, fill='0')
start = tf.string_join(['training_data_', start_idx_str])
end = tf.string_join(['training_data_', end_idx_str])
return table.scan_range(start, end, columns=columns)
ds = ds.apply(tf.data.experimental.parallel_interleave(
interleave_fn, cycle_length=NUM_PARALLEL_READS, prefetch_input_elements=1))
```
> Note: you should divide up the key range into more sub-ranges for increased
> parallelism.
## Writing to Cloud Bigtable
In order to simplify getting started, this package provides basic support for
writing data into Cloud Bigtable.
> Note: The implementation is not optimized for performance! Please consider
> using alternative frameworks such as Apache Beam / Cloud Dataflow for
> production workloads.
Below is an example of how to write a trivial dataset into Cloud Bigtable.
```python
import tensorflow as tf
GCP_PROJECT_ID = '<FILL_ME_IN>'
BIGTABLE_INSTANCE_ID = '<FILL_ME_IN>'
BIGTABLE_TABLE_NAME = '<FILL_ME_IN>'
COLUMN_FAMILY = '<FILL_ME_IN>'
COLUMN_QUALIFIER = '<FILL_ME_IN>'
def make_dataset():
"""Makes a dataset to write to Cloud Bigtable."""
return tf.data.Dataset.from_tensor_slices([
'training_data_1',
'training_data_2',
'training_data_3',
])
def make_row_key_dataset():
"""Makes a dataset of strings used for row keys.
The strings are of the form: `fake-data-` followed by a sequential counter.
For example, this dataset would contain the following elements:
- fake-data-00000001
- fake-data-00000002
- ...
- fake-data-23498103
"""
counter_dataset = tf.data.experimental.Counter()
width = 8
row_key_prefix = 'fake-data-'
ds = counter_dataset.map(lambda index: tf.as_string(index,
width=width,
fill='0'))
ds = ds.map(lambda idx_str: tf.string_join([row_key_prefix, idx_str]))
return ds
def main():
client = tf.contrib.cloud.BigtableClient(GCP_PROJECT_ID, BIGTABLE_INSTANCE_ID)
table = client.table(BIGTABLE_TABLE_NAME)
dataset = make_dataset()
index_dataset = make_row_key_dataset()
aggregate_dataset = tf.data.Dataset.zip((index_dataset, dataset))
write_op = table.write(aggregate_dataset, column_families=[COLUMN_FAMILY],
columns=[COLUMN_QUALIFIER])
with tf.Session() as sess:
print('Starting transfer.')
sess.run(write_op)
print('Transfer complete.')
if __name__ == '__main__':
main()
```
## Sample applications and architectures
While most machine learning applications are well served by a high-performance
distributed file system, there are certain applications where using Cloud
Bigtable works extremely well.
### Perfect Shuffling
Normally, training data is stored in flat files, and a combination of
(1) `tf.data.Dataset.interleave` (or `parallel_interleave`), (2)
`tf.data.Dataset.shuffle`, and (3) writing the data in an unsorted order in the
data files in the first place, provides enough randomization to ensure models
train efficiently. However, if you would like perfect shuffling, you can use
Cloud Bigtable's low-latency random access capabilities. Create a
`tf.data.Dataset` that generates the keys in a perfectly random order (or read
all the keys into memory and use a shuffle buffer sized to fit all of them for a
perfect random shuffle using `tf.data.Dataset.shuffle`), and then use
`lookup_columns` to retrieve the training data.
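The following sketch illustrates the second approach (it assumes all keys fit
in a shuffle buffer and that `table` is a `BigtableTable` as in the examples
above; the prefix, buffer size, and column names are placeholders):

```python
keys = table.keys_by_prefix_dataset('train-')
shuffled_keys = keys.shuffle(buffer_size=1000000)  # >= number of rows.
examples = shuffled_keys.apply(
    table.lookup_columns(('my_column_family', 'column_name')))
```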
### Distributed Reinforcement Learning
Sophisticated reinforcement learning algorithms are commonly trained across a
distributed cluster. (See [IMPALA by DeepMind][impala].) One part of the cluster
runs self-play, while the other part of the cluster learns a new version of the
model based on the training data generated by self-play. The new model version
is then distributed to the self-play half of the cluster, and new training data
is generated to continue the cycle.
In such a configuration, because there is value in training on the freshest
examples, a storage service like Cloud Bigtable can be used to store and
serve the generated training data. When using Cloud Bigtable, there is no need
to aggregate the examples into large batch files, but the examples can instead
be written as soon as they are generated, and then retrieved at high speed.
[impala]: https://arxiv.org/abs/1802.01561
## Common Gotchas!
### gRPC Certificates
If you encounter a log line that includes the following:
```
"description":"Failed to load file", [...],
"filename":"/usr/share/grpc/roots.pem"
```
you can solve it via either of the following approaches:
* copy the [gRPC `roots.pem` file][grpcPem] to
`/usr/share/grpc/roots.pem` on your local machine, which is the default
location where gRPC will look for this file
* export the environment variable `GRPC_DEFAULT_SSL_ROOTS_FILE_PATH` to point to
the full path of the gRPC `roots.pem` file on your file system if it's in a
different location
[grpcPem]: https://github.com/grpc/grpc/blob/master/etc/roots.pem
### Permission denied errors
The TensorFlow Cloud Bigtable client will search for credentials to use in the
process's environment. It will use the first credentials it finds if multiple
are available.
- **Compute Engine**: When running on Compute Engine, the client will often use
the service account from the virtual machine's metadata service. Be sure to
authorize your Compute Engine VM to have access to the Cloud Bigtable service
when creating your VM, or [update the VM's scopes][update-vm-scopes] on a
running VM if you run into this issue.
- **Cloud TPU**: Your Cloud TPUs run with the designated Cloud TPU service
account dedicated to your GCP project. Ensure the service account has been
authorized via the Cloud Console to access your Cloud Bigtable instances.
[update-vm-scopes]: https://cloud.google.com/compute/docs/access/create-enable-service-accounts-for-instances#changeserviceaccountandscopes

View File

@ -1,39 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cloud Bigtable Client for TensorFlow.
This contrib package allows TensorFlow to interface directly with Cloud Bigtable
for high-speed data loading.
@@BigtableClient
@@BigtableTable
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableClient
from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableTable
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'BigtableClient',
'BigtableTable',
]
remove_undocumented(__name__, _allowed_symbols)

View File

@ -1,360 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
namespace tensorflow {
namespace {
class BigtableClientOp : public OpKernel {
public:
explicit BigtableClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("project_id", &project_id_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("instance_id", &instance_id_));
OP_REQUIRES(ctx, !project_id_.empty(),
errors::InvalidArgument("project_id must be non-empty"));
OP_REQUIRES(ctx, !instance_id_.empty(),
errors::InvalidArgument("instance_id must be non-empty"));
OP_REQUIRES_OK(
ctx, ctx->GetAttr("connection_pool_size", &connection_pool_size_));
// If left unset by the client code, set it to a default of 100. Note: the
// cloud-cpp default of 4 concurrent connections is far too low for high
// performance streaming.
if (connection_pool_size_ == -1) {
connection_pool_size_ = 100;
}
OP_REQUIRES_OK(ctx, ctx->GetAttr("max_receive_message_size",
&max_receive_message_size_));
// If left unset by the client code, set it to a default of 16 MB. Note: the
// default gRPC limit of 4 MB is too small for high-performance streaming.
if (max_receive_message_size_ == -1) {
max_receive_message_size_ = 1 << 24; // 16 MBytes
}
OP_REQUIRES(ctx, max_receive_message_size_ > 0,
errors::InvalidArgument("connection_pool_size must be > 0"));
}
~BigtableClientOp() override {
if (cinfo_.resource_is_private_to_kernel()) {
if (!cinfo_.resource_manager()
->Delete<BigtableClientResource>(cinfo_.container(),
cinfo_.name())
.ok()) {
// Do nothing; the resource may have been deleted by session resets.
}
}
}
void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
if (!initialized_) {
ResourceMgr* mgr = ctx->resource_manager();
OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
BigtableClientResource* resource;
OP_REQUIRES_OK(
ctx,
mgr->LookupOrCreate<BigtableClientResource>(
cinfo_.container(), cinfo_.name(), &resource,
[this, ctx](
BigtableClientResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
auto client_options =
google::cloud::bigtable::ClientOptions()
.set_connection_pool_size(connection_pool_size_)
.set_data_endpoint("batch-bigtable.googleapis.com");
auto channel_args = client_options.channel_arguments();
channel_args.SetMaxReceiveMessageSize(
max_receive_message_size_);
channel_args.SetUserAgentPrefix("tensorflow");
channel_args.SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 0);
channel_args.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 60 * 1000);
client_options.set_channel_arguments(channel_args);
std::shared_ptr<google::cloud::bigtable::DataClient> client =
google::cloud::bigtable::CreateDefaultDataClient(
project_id_, instance_id_, std::move(client_options));
*ret = new BigtableClientResource(project_id_, instance_id_,
std::move(client));
return Status::OK();
}));
core::ScopedUnref resource_cleanup(resource);
initialized_ = true;
}
OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
ctx, 0, cinfo_.container(), cinfo_.name(),
MakeTypeIndex<BigtableClientResource>()));
}
private:
string project_id_;
string instance_id_;
int64 connection_pool_size_;
int32 max_receive_message_size_;
mutex mu_;
ContainerInfo cinfo_ GUARDED_BY(mu_);
bool initialized_ GUARDED_BY(mu_) = false;
};
REGISTER_KERNEL_BUILDER(Name("BigtableClient").Device(DEVICE_CPU),
BigtableClientOp);
class BigtableTableOp : public OpKernel {
public:
explicit BigtableTableOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_));
OP_REQUIRES(ctx, !table_.empty(),
errors::InvalidArgument("table_name must be non-empty"));
}
~BigtableTableOp() override {
if (cinfo_.resource_is_private_to_kernel()) {
if (!cinfo_.resource_manager()
->Delete<BigtableTableResource>(cinfo_.container(),
cinfo_.name())
.ok()) {
// Do nothing; the resource may have been deleted by session resets.
}
}
}
void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
if (!initialized_) {
ResourceMgr* mgr = ctx->resource_manager();
OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
core::RefCountPtr<BigtableClientResource> client_resource;
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource));
BigtableTableResource* resource;
OP_REQUIRES_OK(ctx,
mgr->LookupOrCreate<BigtableTableResource>(
cinfo_.container(), cinfo_.name(), &resource,
[this, &client_resource](BigtableTableResource** ret) {
*ret = new BigtableTableResource(
client_resource.get(), table_);
return Status::OK();
}));
initialized_ = true;
}
OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
ctx, 0, cinfo_.container(), cinfo_.name(),
MakeTypeIndex<BigtableTableResource>()));
}
private:
string table_; // Note: this is const after construction.
mutex mu_;
ContainerInfo cinfo_ GUARDED_BY(mu_);
bool initialized_ GUARDED_BY(mu_) = false;
};
REGISTER_KERNEL_BUILDER(Name("BigtableTable").Device(DEVICE_CPU),
BigtableTableOp);
} // namespace
namespace data {
namespace {
class ToBigtableOp : public AsyncOpKernel {
public:
explicit ToBigtableOp(OpKernelConstruction* ctx)
: AsyncOpKernel(ctx),
thread_pool_(new thread::ThreadPool(
ctx->env(), ThreadOptions(),
strings::StrCat("to_bigtable_op_", SanitizeThreadSuffix(name())),
/* num_threads = */ 1, /* low_latency_hint = */ false)) {}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
// The call to `iterator->GetNext()` may block and depend on an
// inter-op thread pool thread, so we issue the call from the
// owned thread pool.
thread_pool_->Schedule([this, ctx, done]() {
const Tensor* column_families_tensor;
OP_REQUIRES_OK_ASYNC(
ctx, ctx->input("column_families", &column_families_tensor), done);
OP_REQUIRES_ASYNC(
ctx, column_families_tensor->dims() == 1,
errors::InvalidArgument("`column_families` must be a vector."), done);
const Tensor* columns_tensor;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input("columns", &columns_tensor), done);
OP_REQUIRES_ASYNC(ctx, columns_tensor->dims() == 1,
errors::InvalidArgument("`columns` must be a vector."),
done);
OP_REQUIRES_ASYNC(
ctx,
columns_tensor->NumElements() ==
column_families_tensor->NumElements(),
errors::InvalidArgument("len(column_families) != len(columns)"),
done);
std::vector<string> column_families;
column_families.reserve(column_families_tensor->NumElements());
std::vector<string> columns;
columns.reserve(column_families_tensor->NumElements());
for (uint64 i = 0; i < column_families_tensor->NumElements(); ++i) {
column_families.push_back(column_families_tensor->flat<tstring>()(i));
columns.push_back(columns_tensor->flat<tstring>()(i));
}
DatasetBase* dataset;
OP_REQUIRES_OK_ASYNC(
ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done);
std::unique_ptr<IteratorBase> iterator;
OP_REQUIRES_OK_ASYNC(
ctx,
dataset->MakeIterator(IteratorContext(ctx), "ToBigtableOpIterator",
&iterator),
done);
int64 timestamp_int;
OP_REQUIRES_OK_ASYNC(
ctx, ParseScalarArgument<int64>(ctx, "timestamp", &timestamp_int),
done);
OP_REQUIRES_ASYNC(ctx, timestamp_int >= -1,
errors::InvalidArgument("timestamp must be >= -1"),
done);
core::RefCountPtr<BigtableTableResource> resource;
OP_REQUIRES_OK_ASYNC(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource), done);
std::vector<Tensor> components;
components.reserve(dataset->output_dtypes().size());
bool end_of_sequence = false;
do {
::google::cloud::bigtable::BulkMutation mutation;
// TODO(saeta): Make # of mutations configurable.
for (uint64 i = 0; i < 100 && !end_of_sequence; ++i) {
OP_REQUIRES_OK_ASYNC(ctx,
iterator->GetNext(IteratorContext(ctx),
&components, &end_of_sequence),
done);
if (!end_of_sequence) {
OP_REQUIRES_OK_ASYNC(
ctx,
CreateMutation(std::move(components), column_families, columns,
timestamp_int, &mutation),
done);
}
components.clear();
}
::google::cloud::Status mutation_status;
std::vector<::google::cloud::bigtable::FailedMutation> failures =
resource->table().BulkApply(mutation);
if (!failures.empty()) {
mutation_status = failures.front().status();
if (!mutation_status.ok()) {
LOG(ERROR) << "Failure applying mutation: "
<< mutation_status.code() << " - "
<< mutation_status.message() << ".";
}
::google::bigtable::v2::MutateRowsRequest request;
mutation.MoveTo(&request);
for (const auto& failure : failures) {
LOG(ERROR) << "Failure applying mutation on row ("
<< failure.original_index() << "): "
<< request.entries(failure.original_index()).row_key()
<< " - error: " << failure.status().message() << ".";
}
}
OP_REQUIRES_ASYNC(
ctx, failures.empty(),
errors::Unknown("Failure while writing to Cloud Bigtable: ",
mutation_status.code(), " - ",
mutation_status.message(),
"; # of mutation failures: ", failures.size(),
". See the log for the specific error details."),
done);
} while (!end_of_sequence);
done();
});
}
private:
static string SanitizeThreadSuffix(string suffix) {
string clean;
for (int i = 0; i < suffix.size(); ++i) {
const char ch = suffix[i];
if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') ||
(ch >= '0' && ch <= '9') || ch == '_' || ch == '-') {
clean += ch;
} else {
clean += '_';
}
}
return clean;
}
Status CreateMutation(
std::vector<Tensor> tensors, const std::vector<string>& column_families,
const std::vector<string>& columns, int64 timestamp_int,
::google::cloud::bigtable::BulkMutation* bulk_mutation) {
if (tensors.size() != column_families.size() + 1) {
return errors::InvalidArgument(
"Iterator produced a set of Tensors shorter than expected");
}
::google::cloud::bigtable::SingleRowMutation mutation(
std::move(tensors[0].scalar<tstring>()()));
std::chrono::milliseconds timestamp(timestamp_int);
for (size_t i = 1; i < tensors.size(); ++i) {
if (!TensorShapeUtils::IsScalar(tensors[i].shape())) {
return errors::Internal("Output tensor ", i, " was not a scalar");
}
if (timestamp_int == -1) {
mutation.emplace_back(::google::cloud::bigtable::SetCell(
column_families[i - 1], columns[i - 1],
std::move(tensors[i].scalar<tstring>()())));
} else {
mutation.emplace_back(::google::cloud::bigtable::SetCell(
column_families[i - 1], columns[i - 1], timestamp,
std::move(tensors[i].scalar<tstring>()())));
}
}
bulk_mutation->emplace_back(std::move(mutation));
return Status::OK();
}
template <typename T>
Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name,
T* output) {
const Tensor* argument_t;
TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
return errors::InvalidArgument(argument_name, " must be a scalar");
}
*output = argument_t->scalar<T>()();
return Status::OK();
}
std::unique_ptr<thread::ThreadPool> thread_pool_;
};
REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU),
ToBigtableOp);
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -1,79 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
namespace tensorflow {
namespace {
::tensorflow::error::Code GcpErrorCodeToTfErrorCode(
::google::cloud::StatusCode code) {
switch (code) {
case ::google::cloud::StatusCode::kOk:
return ::tensorflow::error::OK;
case ::google::cloud::StatusCode::kCancelled:
return ::tensorflow::error::CANCELLED;
case ::google::cloud::StatusCode::kUnknown:
return ::tensorflow::error::UNKNOWN;
case ::google::cloud::StatusCode::kInvalidArgument:
return ::tensorflow::error::INVALID_ARGUMENT;
case ::google::cloud::StatusCode::kDeadlineExceeded:
return ::tensorflow::error::DEADLINE_EXCEEDED;
case ::google::cloud::StatusCode::kNotFound:
return ::tensorflow::error::NOT_FOUND;
case ::google::cloud::StatusCode::kAlreadyExists:
return ::tensorflow::error::ALREADY_EXISTS;
case ::google::cloud::StatusCode::kPermissionDenied:
return ::tensorflow::error::PERMISSION_DENIED;
case ::google::cloud::StatusCode::kUnauthenticated:
return ::tensorflow::error::UNAUTHENTICATED;
case ::google::cloud::StatusCode::kResourceExhausted:
return ::tensorflow::error::RESOURCE_EXHAUSTED;
case ::google::cloud::StatusCode::kFailedPrecondition:
return ::tensorflow::error::FAILED_PRECONDITION;
case ::google::cloud::StatusCode::kAborted:
return ::tensorflow::error::ABORTED;
case ::google::cloud::StatusCode::kOutOfRange:
return ::tensorflow::error::OUT_OF_RANGE;
case ::google::cloud::StatusCode::kUnimplemented:
return ::tensorflow::error::UNIMPLEMENTED;
case ::google::cloud::StatusCode::kInternal:
return ::tensorflow::error::INTERNAL;
case ::google::cloud::StatusCode::kUnavailable:
return ::tensorflow::error::UNAVAILABLE;
case ::google::cloud::StatusCode::kDataLoss:
return ::tensorflow::error::DATA_LOSS;
}
}
} // namespace
Status GcpStatusToTfStatus(const ::google::cloud::Status& status) {
if (status.ok()) {
return Status::OK();
}
return Status(
GcpErrorCodeToTfErrorCode(status.code()),
strings::StrCat("Error reading from Cloud Bigtable: ", status.message()));
}
string RegexFromStringSet(const std::vector<tstring>& strs) {
CHECK(!strs.empty()) << "The list of strings to turn into a regex was empty.";
std::unordered_set<tstring> uniq(strs.begin(), strs.end());
if (uniq.size() == 1) {
return *uniq.begin();
}
return absl::StrJoin(uniq, "|");
}
} // namespace tensorflow

View File

@ -1,153 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_
#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_
#include "google/cloud/bigtable/data_client.h"
#include "google/cloud/bigtable/table.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/resource_mgr.h"
namespace tensorflow {
Status GcpStatusToTfStatus(const ::google::cloud::Status& status);
string RegexFromStringSet(const std::vector<tstring>& strs);
class BigtableClientResource : public ResourceBase {
public:
BigtableClientResource(
string project_id, string instance_id,
std::shared_ptr<google::cloud::bigtable::DataClient> client)
: project_id_(std::move(project_id)),
instance_id_(std::move(instance_id)),
client_(std::move(client)) {}
std::shared_ptr<google::cloud::bigtable::DataClient> get_client() {
return client_;
}
string DebugString() const override {
return strings::StrCat("BigtableClientResource(project_id: ", project_id_,
", instance_id: ", instance_id_, ")");
}
private:
const string project_id_;
const string instance_id_;
std::shared_ptr<google::cloud::bigtable::DataClient> client_;
};
class BigtableTableResource : public ResourceBase {
public:
BigtableTableResource(BigtableClientResource* client, string table_name)
: client_(client),
table_name_(std::move(table_name)),
table_(client->get_client(), table_name_,
google::cloud::bigtable::AlwaysRetryMutationPolicy()) {
client_->Ref();
}
~BigtableTableResource() override { client_->Unref(); }
::google::cloud::bigtable::Table& table() { return table_; }
string DebugString() const override {
return strings::StrCat(
"BigtableTableResource(client: ", client_->DebugString(),
", table: ", table_name_, ")");
}
private:
BigtableClientResource* client_;  // Owns one ref.
const string table_name_;
::google::cloud::bigtable::Table table_;
};
namespace data {
// BigtableReaderDatasetIterator is an abstract class for iterators from
// datasets that are "readers" (source datasets, not transformation datasets)
// that read from Bigtable.
template <typename Dataset>
class BigtableReaderDatasetIterator : public DatasetIterator<Dataset> {
public:
explicit BigtableReaderDatasetIterator(
const typename DatasetIterator<Dataset>::Params& params)
: DatasetIterator<Dataset>(params) {}
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(EnsureIteratorInitialized());
if (iterator_ == reader_->end()) {
*end_of_sequence = true;
return Status::OK();
}
if (!*iterator_) {
return GcpStatusToTfStatus(iterator_->status());
}
*end_of_sequence = false;
google::cloud::bigtable::Row& row = **iterator_;
Status s = ParseRow(ctx, row, out_tensors);
// Ensure we always advance.
++iterator_;
return s;
}
protected:
virtual ::google::cloud::bigtable::RowRange MakeRowRange() = 0;
virtual ::google::cloud::bigtable::Filter MakeFilter() = 0;
virtual Status ParseRow(IteratorContext* ctx,
const ::google::cloud::bigtable::Row& row,
std::vector<Tensor>* out_tensors) = 0;
Status SaveInternal(IteratorStateWriter* writer) override {
return errors::Unimplemented("SaveInternal is currently not supported");
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
return errors::Unimplemented("RestoreInternal is currently not supported");
}
private:
Status EnsureIteratorInitialized() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (reader_) {
return Status::OK();
}
auto rows = MakeRowRange();
auto filter = MakeFilter();
// Note: the explicit `this->` in `this->dataset()` below is necessary
// because `dataset()` is a member of a dependent base class.
reader_.reset(new ::google::cloud::bigtable::RowReader(
this->dataset()->table()->table().ReadRows(rows, filter)));
iterator_ = reader_->begin();
return Status::OK();
}
mutex mu_;
std::unique_ptr<::google::cloud::bigtable::RowReader> reader_ GUARDED_BY(mu_);
::google::cloud::bigtable::RowReader::iterator iterator_ GUARDED_BY(mu_);
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_

View File

@ -1,248 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
namespace data {
namespace {
class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
public:
using UnaryDatasetOpKernel::UnaryDatasetOpKernel;
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
core::RefCountPtr<BigtableTableResource> table;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table));
std::vector<tstring> column_families;
std::vector<tstring> columns;
OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, "column_families",
&column_families));
OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, "columns", &columns));
OP_REQUIRES(
ctx, column_families.size() == columns.size(),
errors::InvalidArgument("len(columns) != len(column_families)"));
const uint64 num_outputs = columns.size() + 1;
std::vector<PartialTensorShape> output_shapes;
output_shapes.reserve(num_outputs);
DataTypeVector output_types;
output_types.reserve(num_outputs);
for (uint64 i = 0; i < num_outputs; ++i) {
output_shapes.push_back({});
output_types.push_back(DT_STRING);
}
*output =
new Dataset(ctx, input, table.get(), std::move(column_families),
std::move(columns), output_types, std::move(output_shapes));
}
private:
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
BigtableTableResource* table,
std::vector<tstring> column_families,
std::vector<tstring> columns,
const DataTypeVector& output_types,
std::vector<PartialTensorShape> output_shapes)
: DatasetBase(DatasetContext(ctx)),
input_(input),
table_(table),
column_families_(std::move(column_families)),
columns_(std::move(columns)),
output_types_(output_types),
output_shapes_(std::move(output_shapes)),
filter_(MakeFilter(column_families_, columns_)) {
table_->Ref();
input_->Ref();
}
~Dataset() override {
table_->Unref();
input_->Unref();
}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::BigtableLookup")}));
}
const DataTypeVector& output_dtypes() const override {
return output_types_;
}
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
string DebugString() const override {
return "BigtableLookupDatasetOp::Dataset";
}
Status CheckExternalState() const override {
return errors::FailedPrecondition(DebugString(),
" depends on external state.");
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
return errors::Unimplemented(DebugString(),
" does not support serialization");
}
private:
static ::google::cloud::bigtable::Filter MakeFilter(
const std::vector<tstring>& column_families,
const std::vector<tstring>& columns) {
string column_family_regex = RegexFromStringSet(column_families);
string column_regex = RegexFromStringSet(columns);
return ::google::cloud::bigtable::Filter::Chain(
::google::cloud::bigtable::Filter::Latest(1),
::google::cloud::bigtable::Filter::FamilyRegex(column_family_regex),
::google::cloud::bigtable::Filter::ColumnRegex(column_regex));
}
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_); // Sequence requests.
std::vector<Tensor> input_tensors;
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, &input_tensors, end_of_sequence));
if (*end_of_sequence) {
return Status::OK();
}
if (input_tensors.size() != 1) {
return errors::InvalidArgument(
"Upstream iterator (", dataset()->input_->DebugString(),
") did not produce a single `tf.string` `tf.Tensor`. It "
"produced ",
input_tensors.size(), " tensors.");
}
if (input_tensors[0].NumElements() == 0) {
return errors::InvalidArgument("Upstream iterator (",
dataset()->input_->DebugString(),
") return an empty set of keys.");
}
if (input_tensors[0].NumElements() == 1) {
// Single key lookup.
::google::cloud::StatusOr<
std::pair<bool, ::google::cloud::bigtable::Row>>
row = dataset()->table_->table().ReadRow(
input_tensors[0].scalar<tstring>()(), dataset()->filter_);
if (!row.ok()) {
return GcpStatusToTfStatus(row.status());
}
if (!row->first) {
return errors::DataLoss("Row key '",
input_tensors[0].scalar<tstring>()(),
"' not found.");
}
TF_RETURN_IF_ERROR(ParseRow(ctx, row->second, out_tensors));
} else {
// Batched get.
return errors::Unimplemented(
"BigtableLookupDataset doesn't yet support batched retrieval.");
}
return Status::OK();
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
return errors::Unimplemented("SaveInternal is currently not supported");
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
return errors::Unimplemented(
"RestoreInternal is currently not supported");
}
private:
Status ParseRow(IteratorContext* ctx,
const ::google::cloud::bigtable::Row& row,
std::vector<Tensor>* out_tensors) {
out_tensors->reserve(dataset()->columns_.size() + 1);
Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {});
row_key_tensor.scalar<tstring>()() = tstring(row.row_key());
out_tensors->emplace_back(std::move(row_key_tensor));
if (row.cells().size() > 2 * dataset()->columns_.size()) {
LOG(WARNING) << "An excessive number of columns ("
<< row.cells().size()
<< ") were retrieved when reading row: "
<< row.row_key();
}
for (uint64 i = 0; i < dataset()->columns_.size(); ++i) {
Tensor col_tensor(ctx->allocator({}), DT_STRING, {});
bool found_column = false;
for (auto cell_itr = row.cells().begin();
!found_column && cell_itr != row.cells().end(); ++cell_itr) {
if (cell_itr->family_name() == dataset()->column_families_[i] &&
tstring(cell_itr->column_qualifier()) ==
dataset()->columns_[i]) {
col_tensor.scalar<tstring>()() = tstring(cell_itr->value());
found_column = true;
}
}
if (!found_column) {
return errors::DataLoss("Column ", dataset()->column_families_[i],
":", dataset()->columns_[i],
" not found in row: ", row.row_key());
}
out_tensors->emplace_back(std::move(col_tensor));
}
return Status::OK();
}
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
const DatasetBase* const input_;
BigtableTableResource* table_;
const std::vector<tstring> column_families_;
const std::vector<tstring> columns_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
const ::google::cloud::bigtable::Filter filter_;
};
};
REGISTER_KERNEL_BUILDER(Name("BigtableLookupDataset").Device(DEVICE_CPU),
BigtableLookupDatasetOp);
} // namespace
} // namespace data
} // namespace tensorflow
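
For reference, the single-key path in GetNextInternal above reduces to one
google-cloud-cpp call. The following is a minimal sketch, not part of the
deleted file: it assumes a `google::cloud::bigtable::Table` handle, and
"cf1", "col1", and "row-key" are hypothetical placeholders. It chains the
same Latest/FamilyRegex/ColumnRegex filters that MakeFilter builds.

// Sketch only: mirrors BigtableLookupDatasetOp's single-key lookup using the
// google-cloud-cpp client directly.
#include <iostream>
#include <utility>
#include "google/cloud/bigtable/table.h"

namespace cbt = ::google::cloud::bigtable;

void LookupExample(cbt::Table& table) {
  // Same chain as MakeFilter(): newest cell only, restricted to the
  // requested column family and column qualifier.
  auto filter = cbt::Filter::Chain(cbt::Filter::Latest(1),
                                   cbt::Filter::FamilyRegex("cf1"),
                                   cbt::Filter::ColumnRegex("col1"));
  // ReadRow returns StatusOr<std::pair<bool, Row>>; `first` reports whether
  // the row exists, matching the DataLoss check in GetNextInternal above.
  auto row = table.ReadRow("row-key", std::move(filter));
  if (!row.ok()) {
    std::cerr << "read failed: " << row.status() << "\n";
    return;
  }
  if (!row->first) {
    std::cerr << "row not found\n";
    return;
  }
  for (const auto& cell : row->second.cells()) {
    std::cout << cell.family_name() << ":" << cell.column_qualifier()
              << " = " << cell.value() << "\n";
  }
}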

View File

@ -1,121 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow {
namespace data {
namespace {
class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
public:
using DatasetOpKernel::DatasetOpKernel;
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
tstring prefix;
OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "prefix", &prefix));
core::RefCountPtr<BigtableTableResource> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
*output = new Dataset(ctx, resource.get(), std::move(prefix));
}
private:
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
string prefix)
: DatasetBase(DatasetContext(ctx)),
table_(table),
prefix_(std::move(prefix)) {
table_->Ref();
}
~Dataset() override { table_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::BigtablePrefixKey")}));
}
const DataTypeVector& output_dtypes() const override {
static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
return *dtypes;
}
const std::vector<PartialTensorShape>& output_shapes() const override {
static std::vector<PartialTensorShape>* shapes =
new std::vector<PartialTensorShape>({{}});
return *shapes;
}
string DebugString() const override {
return "BigtablePrefixKeyDatasetOp::Dataset";
}
BigtableTableResource* table() const { return table_; }
Status CheckExternalState() const override {
return errors::FailedPrecondition(DebugString(),
" depends on external state.");
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
return errors::Unimplemented(DebugString(),
" does not support serialization");
}
private:
class Iterator : public BigtableReaderDatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: BigtableReaderDatasetIterator<Dataset>(params) {}
::google::cloud::bigtable::RowRange MakeRowRange() override {
return ::google::cloud::bigtable::RowRange::Prefix(dataset()->prefix_);
}
::google::cloud::bigtable::Filter MakeFilter() override {
return ::google::cloud::bigtable::Filter::Chain(
::google::cloud::bigtable::Filter::CellsRowLimit(1),
::google::cloud::bigtable::Filter::StripValueTransformer());
}
Status ParseRow(IteratorContext* ctx,
const ::google::cloud::bigtable::Row& row,
std::vector<Tensor>* out_tensors) override {
Tensor output_tensor(ctx->allocator({}), DT_STRING, {});
output_tensor.scalar<tstring>()() = tstring(row.row_key());
out_tensors->emplace_back(std::move(output_tensor));
return Status::OK();
}
};
BigtableTableResource* const table_;
const string prefix_;
};
};
REGISTER_KERNEL_BUILDER(Name("BigtablePrefixKeyDataset").Device(DEVICE_CPU),
BigtablePrefixKeyDatasetOp);
} // namespace
} // namespace data
} // namespace tensorflow
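
The key-only scan above relies on CellsRowLimit(1) plus StripValueTransformer
so that each row costs a single value-stripped cell. A minimal sketch of the
same scan against the client library directly, assuming a `cbt::Table` handle
and a hypothetical prefix:

// Sketch only: enumerates row keys under a prefix the way the iterator above
// does, fetching at most one stripped cell per row.
#include <iostream>
#include <string>
#include "google/cloud/bigtable/table.h"

namespace cbt = ::google::cloud::bigtable;

void ListKeysWithPrefix(cbt::Table& table, const std::string& prefix) {
  auto filter = cbt::Filter::Chain(cbt::Filter::CellsRowLimit(1),
                                   cbt::Filter::StripValueTransformer());
  // ReadRows streams StatusOr<Row>; only the row key is meaningful here,
  // since every cell value was stripped server-side.
  for (auto& row : table.ReadRows(cbt::RowRange::Prefix(prefix), filter)) {
    if (!row.ok()) {
      std::cerr << "scan failed: " << row.status() << "\n";
      return;
    }
    std::cout << row->row_key() << "\n";
  }
}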

View File

@ -1,68 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace {
string MakePrefixEndKey(const string& prefix) {
string end = prefix;
while (true) {
if (end.empty()) {
return end;
}
++end[end.size() - 1];
if (end[end.size() - 1] == 0) {
// Handle wraparound case.
end = end.substr(0, end.size() - 1);
} else {
return end;
}
}
}
} // namespace
/* static */ MultiModeKeyRange MultiModeKeyRange::FromPrefix(string prefix) {
string end = MakePrefixEndKey(prefix);
VLOG(1) << "Creating MultiModeKeyRange from Prefix: " << prefix
<< ", with end key: " << end;
return MultiModeKeyRange(std::move(prefix), std::move(end));
}
/* static */ MultiModeKeyRange MultiModeKeyRange::FromRange(string begin,
string end) {
return MultiModeKeyRange(std::move(begin), std::move(end));
}
const string& MultiModeKeyRange::begin_key() const { return begin_; }
const string& MultiModeKeyRange::end_key() const { return end_; }
bool MultiModeKeyRange::contains_key(StringPiece key) const {
if (StringPiece(begin_) > key) {
return false;
}
if (StringPiece(end_) <= key && !end_.empty()) {
return false;
}
return true;
}
} // namespace tensorflow
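
MakePrefixEndKey above computes the smallest key strictly greater than every
key sharing the prefix: increment the last byte, popping trailing 0xFF bytes
that wrap around. A self-contained sketch of the same computation, with the
worked cases that the unit tests below exercise:

// Sketch only: standalone restatement of the prefix-successor logic.
#include <cassert>
#include <string>

std::string PrefixSuccessor(std::string end) {
  while (!end.empty()) {
    // Incrementing '\xff' wraps to '\0': drop the byte and carry leftwards.
    if (static_cast<unsigned char>(++end.back()) != 0) return end;
    end.pop_back();
  }
  return end;  // Every byte was 0xff, so the range is unbounded above.
}

int main() {
  assert(PrefixSuccessor("prefix") == "prefiy");
  assert(PrefixSuccessor("abc\xff") == "abd");
  assert(PrefixSuccessor("").empty());  // Empty end key means "no upper bound".
  return 0;
}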

View File

@ -1,68 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_
#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_
#include <string>
#include <utility>
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
// Represents a continuous range of keys defined by either a prefix or a range.
//
// Ranges are represented as "half-open", where the beginning key is included
// in the range, and the end_key is the first excluded key after the range.
//
// The range of keys can be specified either by a key prefix, or by an explicit
// begin key and end key. All methods on this class are valid no matter which
// way the range was specified.
//
// Example:
// MultiModeKeyRange range = MultiModeKeyRange::FromPrefix("myPrefix");
// if (range.contains_key("myPrefixedKey")) {
// LOG(INFO) << "range from " << range.begin_key() << " to "
//             << range.end_key() << " contains \"myPrefixedKey\"";
// }
// if (!range.contains_key("randomKey")) {
// LOG(INFO) << "range does not contain \"randomKey\"";
// }
// range = MultiModeKeyRange::FromRange("a_start_key", "z_end_key");
class MultiModeKeyRange {
public:
static MultiModeKeyRange FromPrefix(string prefix);
static MultiModeKeyRange FromRange(string begin, string end);
// The first valid key in the range.
const string& begin_key() const;
// The first invalid key after the valid range.
const string& end_key() const;
// Returns true if the provided key is a part of the range, false otherwise.
bool contains_key(StringPiece key) const;
private:
MultiModeKeyRange(string begin, string end)
: begin_(std::move(begin)), end_(std::move(end)) {}
const string begin_;
const string end_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_

View File

@ -1,107 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
TEST(MultiModeKeyRangeTest, SimplePrefix) {
MultiModeKeyRange r = MultiModeKeyRange::FromPrefix("prefix");
EXPECT_EQ("prefix", r.begin_key());
EXPECT_EQ("prefiy", r.end_key());
EXPECT_TRUE(r.contains_key("prefixed_key"));
EXPECT_FALSE(r.contains_key("not-prefixed-key"));
EXPECT_FALSE(r.contains_key("prefi"));
EXPECT_FALSE(r.contains_key("prefiy"));
EXPECT_FALSE(r.contains_key("early"));
EXPECT_FALSE(r.contains_key(""));
}
TEST(MultiModeKeyRangeTest, Range) {
MultiModeKeyRange r = MultiModeKeyRange::FromRange("a", "b");
EXPECT_EQ("a", r.begin_key());
EXPECT_EQ("b", r.end_key());
EXPECT_TRUE(r.contains_key("a"));
EXPECT_TRUE(r.contains_key("ab"));
EXPECT_FALSE(r.contains_key("b"));
EXPECT_FALSE(r.contains_key("bc"));
EXPECT_FALSE(r.contains_key("A"));
EXPECT_FALSE(r.contains_key("B"));
EXPECT_FALSE(r.contains_key(""));
}
TEST(MultiModeKeyRangeTest, InvertedRange) {
MultiModeKeyRange r = MultiModeKeyRange::FromRange("b", "a");
EXPECT_FALSE(r.contains_key("a"));
EXPECT_FALSE(r.contains_key("b"));
EXPECT_FALSE(r.contains_key(""));
}
TEST(MultiModeKeyRangeTest, EmptyPrefix) {
MultiModeKeyRange r = MultiModeKeyRange::FromPrefix("");
EXPECT_EQ("", r.begin_key());
EXPECT_EQ("", r.end_key());
EXPECT_TRUE(r.contains_key(""));
EXPECT_TRUE(r.contains_key("a"));
EXPECT_TRUE(r.contains_key("z"));
EXPECT_TRUE(r.contains_key("A"));
EXPECT_TRUE(r.contains_key("ZZZZZZ"));
}
TEST(MultiModeKeyRangeTest, HalfRange) {
MultiModeKeyRange r = MultiModeKeyRange::FromRange("start", "");
EXPECT_EQ("start", r.begin_key());
EXPECT_EQ("", r.end_key());
EXPECT_TRUE(r.contains_key("start"));
EXPECT_TRUE(r.contains_key("starting"));
EXPECT_TRUE(r.contains_key("z-end"));
EXPECT_FALSE(r.contains_key(""));
EXPECT_FALSE(r.contains_key("early"));
}
TEST(MultiModeKeyRangeTest, PrefixWrapAround) {
string prefix = "abc\xff";
MultiModeKeyRange r = MultiModeKeyRange::FromPrefix(prefix);
EXPECT_EQ(prefix, r.begin_key());
EXPECT_EQ("abd", r.end_key());
EXPECT_TRUE(r.contains_key("abc\xff\x07"));
EXPECT_TRUE(r.contains_key("abc\xff\x15"));
EXPECT_TRUE(r.contains_key("abc\xff\x61"));
EXPECT_TRUE(r.contains_key("abc\xff\xff"));
EXPECT_FALSE(r.contains_key("abc\0"));
EXPECT_FALSE(r.contains_key("abd"));
}
TEST(MultiModeKeyRangeTest, PrefixSignedWrapAround) {
string prefix = "abc\x7f";
MultiModeKeyRange r = MultiModeKeyRange::FromPrefix(prefix);
EXPECT_EQ(prefix, r.begin_key());
EXPECT_EQ("abc\x80", r.end_key());
EXPECT_TRUE(r.contains_key("abc\x7f\x07"));
EXPECT_TRUE(r.contains_key("abc\x7f\x15"));
EXPECT_TRUE(r.contains_key("abc\x7f\x61"));
EXPECT_TRUE(r.contains_key("abc\x7f\xff"));
EXPECT_FALSE(r.contains_key("abc\0"));
EXPECT_FALSE(r.contains_key("abc\x01"));
EXPECT_FALSE(r.contains_key("abd"));
EXPECT_FALSE(r.contains_key("ab\x80"));
}
} // namespace
} // namespace tensorflow

View File

@ -1,127 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow {
namespace data {
namespace {
class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
public:
using DatasetOpKernel::DatasetOpKernel;
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
tstring start_key;
OP_REQUIRES_OK(ctx,
ParseScalarArgument<tstring>(ctx, "start_key", &start_key));
tstring end_key;
OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "end_key", &end_key));
core::RefCountPtr<BigtableTableResource> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
*output = new Dataset(ctx, resource.get(), std::move(start_key),
std::move(end_key));
}
private:
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
string start_key, string end_key)
: DatasetBase(DatasetContext(ctx)),
table_(table),
start_key_(std::move(start_key)),
end_key_(std::move(end_key)) {
table_->Ref();
}
~Dataset() override { table_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::BigtableRangeKey")}));
}
const DataTypeVector& output_dtypes() const override {
static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
return *dtypes;
}
const std::vector<PartialTensorShape>& output_shapes() const override {
static std::vector<PartialTensorShape>* shapes =
new std::vector<PartialTensorShape>({{}});
return *shapes;
}
string DebugString() const override {
return "BigtableRangeKeyDatasetOp::Dataset";
}
BigtableTableResource* table() const { return table_; }
Status CheckExternalState() const override {
return errors::FailedPrecondition(DebugString(),
" depends on external state.");
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
return errors::Unimplemented(DebugString(),
" does not support serialization");
}
private:
class Iterator : public BigtableReaderDatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: BigtableReaderDatasetIterator<Dataset>(params) {}
::google::cloud::bigtable::RowRange MakeRowRange() override {
return ::google::cloud::bigtable::RowRange::Range(dataset()->start_key_,
dataset()->end_key_);
}
::google::cloud::bigtable::Filter MakeFilter() override {
return ::google::cloud::bigtable::Filter::Chain(
::google::cloud::bigtable::Filter::CellsRowLimit(1),
::google::cloud::bigtable::Filter::StripValueTransformer());
}
Status ParseRow(IteratorContext* ctx,
const ::google::cloud::bigtable::Row& row,
std::vector<Tensor>* out_tensors) override {
Tensor output_tensor(ctx->allocator({}), DT_STRING, {});
output_tensor.scalar<tstring>()() = tstring(row.row_key());
out_tensors->emplace_back(std::move(output_tensor));
return Status::OK();
}
};
BigtableTableResource* const table_;
const string start_key_;
const string end_key_;
};
};
REGISTER_KERNEL_BUILDER(Name("BigtableRangeKeyDataset").Device(DEVICE_CPU),
BigtableRangeKeyDatasetOp);
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -1,227 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow {
namespace data {
namespace {
class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
public:
using DatasetOpKernel::DatasetOpKernel;
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
tstring prefix;
OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "prefix", &prefix));
tstring start_key;
OP_REQUIRES_OK(ctx,
ParseScalarArgument<tstring>(ctx, "start_key", &start_key));
tstring end_key;
OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "end_key", &end_key));
core::RefCountPtr<BigtableTableResource> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
errors::InvalidArgument(
"Only one of prefix and start_key can be provided"));
if (!prefix.empty()) {
OP_REQUIRES(ctx, end_key.empty(),
errors::InvalidArgument(
"If prefix is specified, end_key must be empty."));
}
*output = new Dataset(ctx, resource.get(), std::move(prefix),
std::move(start_key), std::move(end_key));
}
private:
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
string prefix, string start_key, string end_key)
: DatasetBase(DatasetContext(ctx)),
table_(table),
key_range_(MakeMultiModeKeyRange(
std::move(prefix), std::move(start_key), std::move(end_key))) {
table_->Ref();
}
~Dataset() override { table_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::BigtableSampleKeyPairs")}));
}
const DataTypeVector& output_dtypes() const override {
static DataTypeVector* dtypes =
new DataTypeVector({DT_STRING, DT_STRING});
return *dtypes;
}
const std::vector<PartialTensorShape>& output_shapes() const override {
static std::vector<PartialTensorShape>* shapes =
new std::vector<PartialTensorShape>({{}, {}});
return *shapes;
}
string DebugString() const override {
return "BigtableSampleKeyPairsDatasetOp::Dataset";
}
Status CheckExternalState() const override {
return errors::FailedPrecondition(DebugString(),
" depends on external state.");
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
return errors::Unimplemented(DebugString(),
" does not support serialization");
}
private:
static MultiModeKeyRange MakeMultiModeKeyRange(string prefix,
string start_key,
string end_key) {
if (!start_key.empty()) {
return MultiModeKeyRange::FromRange(std::move(start_key),
std::move(end_key));
}
return MultiModeKeyRange::FromPrefix(std::move(prefix));
}
BigtableTableResource& table() const { return *table_; }
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}
// Computes split points (`keys_`) to use when scanning the table.
//
      // Initialize first retrieves the sampled row keys from the table
      // (`row_key_samples`), as these often form good split points within
      // the table. We then iterate
// over them, and copy them to `keys_` if they fall within the requested
// range to scan (`dataset()->key_range_`). Because the requested range
// might start between elements of the sampled keys list, care is taken to
// ensure we don't accidentally miss any subsets of the requested range by
// including `begin_key()` and `end_key()` as appropriate.
Status Initialize(IteratorContext* ctx) override {
::google::cloud::StatusOr<
std::vector<::google::cloud::bigtable::RowKeySample>>
row_key_samples = dataset()->table().table().SampleRows();
if (!row_key_samples.ok()) {
return GcpStatusToTfStatus(row_key_samples.status());
}
for (const auto& row_key_sample : *row_key_samples) {
string row_key(row_key_sample.row_key);
if (dataset()->key_range_.contains_key(row_key)) {
// First key: check to see if we need to add the begin_key.
if (keys_.empty() && dataset()->key_range_.begin_key() != row_key) {
keys_.push_back(dataset()->key_range_.begin_key());
}
keys_.push_back(std::move(row_key));
} else if (!keys_.empty()) {
// If !keys_.empty(), then we have found at least one element of
            // the sampled keys that is within our requested range
            // (`dataset()->key_range_`). Because the samples are sorted, if we
// have found an element that's not within our key range, then we
// are after our requested range (ranges are contiguous) and can end
// iteration early.
break;
}
}
// Handle the case where we skip over the selected range entirely.
if (keys_.empty()) {
keys_.push_back(dataset()->key_range_.begin_key());
}
// Last key: check to see if we need to add the end_key.
if (keys_.back() != dataset()->key_range_.end_key()) {
keys_.push_back(dataset()->key_range_.end_key());
}
return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
if (index_ + 2 > keys_.size()) {
*end_of_sequence = true;
return Status::OK();
}
*end_of_sequence = false;
out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
TensorShape({}));
out_tensors->back().scalar<tstring>()() = keys_[index_];
out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
TensorShape({}));
out_tensors->back().scalar<tstring>()() = keys_[index_ + 1];
++index_;
return Status::OK();
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
return errors::Unimplemented("SaveInternal is currently not supported");
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
return errors::Unimplemented(
"RestoreInternal is currently not supported");
}
private:
mutex mu_;
size_t index_ GUARDED_BY(mu_) = 0;
// Note: we store the keys_ on the iterator instead of the dataset
// because we want to re-sample the row keys in case there have been
// tablet rebalancing operations since the dataset was created.
//
// Note: keys_ is readonly after Initialize, and thus does not need a
// guarding lock.
std::vector<string> keys_;
};
BigtableTableResource* const table_;
const MultiModeKeyRange key_range_;
};
};
REGISTER_KERNEL_BUILDER(
Name("BigtableSampleKeyPairsDataset").Device(DEVICE_CPU),
BigtableSampleKeyPairsDatasetOp);
} // namespace
} // namespace data
} // namespace tensorflow
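
The split-point selection in Initialize above is easiest to see in isolation:
clamp the sorted samples to the requested half-open range, then pad with the
range's own begin and end keys so no slice is lost. A minimal sketch over
plain std::string keys (hypothetical inputs; an empty end key means
"unbounded"):

// Sketch only: reproduces the key selection from Iterator::Initialize().
#include <string>
#include <vector>

std::vector<std::string> MakeSplitPoints(
    const std::vector<std::string>& samples, const std::string& begin,
    const std::string& end) {
  auto contains = [&](const std::string& key) {
    return begin <= key && (end.empty() || key < end);
  };
  std::vector<std::string> keys;
  for (const std::string& key : samples) {  // Samples arrive sorted.
    if (contains(key)) {
      // Don't lose the slice between `begin` and the first in-range sample.
      if (keys.empty() && key != begin) keys.push_back(begin);
      keys.push_back(key);
    } else if (!keys.empty()) {
      break;  // Sorted input: once past the range, nothing later matches.
    }
  }
  if (keys.empty()) keys.push_back(begin);  // Range fell between two samples.
  if (keys.back() != end) keys.push_back(end);  // Close the final pair.
  return keys;
}

Consecutive entries of the result then form the (start, end) pairs that
GetNextInternal above emits two tensors at a time.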

View File

@ -1,141 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
namespace data {
namespace {
class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
public:
using DatasetOpKernel::DatasetOpKernel;
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
core::RefCountPtr<BigtableTableResource> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
*output = new Dataset(ctx, resource.get());
}
private:
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table)
: DatasetBase(DatasetContext(ctx)), table_(table) {
table_->Ref();
}
~Dataset() override { table_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::BigtableSampleKeys")}));
}
const DataTypeVector& output_dtypes() const override {
static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
return *dtypes;
}
const std::vector<PartialTensorShape>& output_shapes() const override {
static std::vector<PartialTensorShape>* shapes =
new std::vector<PartialTensorShape>({{}});
return *shapes;
}
string DebugString() const override {
return "BigtableRangeKeyDatasetOp::Dataset";
}
BigtableTableResource* table() const { return table_; }
Status CheckExternalState() const override {
return errors::FailedPrecondition(DebugString(),
" depends on external state.");
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
return errors::Unimplemented(DebugString(),
" does not support serialization");
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
::google::cloud::StatusOr<
std::vector<::google::cloud::bigtable::RowKeySample>>
sampled_rows = dataset()->table()->table().SampleRows();
if (!sampled_rows.ok()) {
row_keys_.clear();
return GcpStatusToTfStatus(sampled_rows.status());
}
row_keys_ = std::move(*sampled_rows);
return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
if (index_ < row_keys_.size()) {
out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
TensorShape({}));
out_tensors->back().scalar<tstring>()() =
tstring(row_keys_[index_].row_key);
*end_of_sequence = false;
index_++;
} else {
*end_of_sequence = true;
}
return Status::OK();
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
return errors::Unimplemented("SaveInternal is currently not supported");
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
return errors::Unimplemented(
"RestoreInternal is currently not supported");
}
private:
mutex mu_;
      size_t index_ GUARDED_BY(mu_) = 0;
std::vector<::google::cloud::bigtable::RowKeySample> row_keys_;
};
BigtableTableResource* const table_;
};
};
REGISTER_KERNEL_BUILDER(Name("BigtableSampleKeysDataset").Device(DEVICE_CPU),
BigtableSampleKeysDatasetOp);
} // namespace
} // namespace data
} // namespace tensorflow
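
The iterator above is a thin wrapper over the client's row-key sampling RPC.
A minimal sketch of the same call, assuming a `cbt::Table` handle:

// Sketch only: prints the sampled row keys the dataset above would emit.
#include <iostream>
#include "google/cloud/bigtable/table.h"

namespace cbt = ::google::cloud::bigtable;

void PrintSampledKeys(cbt::Table& table) {
  // SampleRows() returns StatusOr<std::vector<cbt::RowKeySample>>.
  auto samples = table.SampleRows();
  if (!samples.ok()) {
    std::cerr << "sampling failed: " << samples.status() << "\n";
    return;
  }
  for (const auto& sample : *samples) {
    std::cout << sample.row_key << " (~" << sample.offset_bytes << " bytes)\n";
  }
}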

View File

@ -1,235 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow {
namespace data {
namespace {
class BigtableScanDatasetOp : public DatasetOpKernel {
public:
using DatasetOpKernel::DatasetOpKernel;
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
tstring prefix;
OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "prefix", &prefix));
tstring start_key;
OP_REQUIRES_OK(ctx,
ParseScalarArgument<tstring>(ctx, "start_key", &start_key));
tstring end_key;
OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, "end_key", &end_key));
OP_REQUIRES(ctx, !(prefix.empty() && start_key.empty()),
errors::InvalidArgument(
"Either prefix or start_key must be specified"));
OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
errors::InvalidArgument(
"Only one of prefix and start_key can be provided"));
if (!prefix.empty()) {
OP_REQUIRES(ctx, end_key.empty(),
errors::InvalidArgument(
"If prefix is specified, end_key must be empty."));
}
std::vector<tstring> column_families;
std::vector<tstring> columns;
OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, "column_families",
&column_families));
OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, "columns", &columns));
OP_REQUIRES(
ctx, column_families.size() == columns.size(),
errors::InvalidArgument("len(columns) != len(column_families)"));
OP_REQUIRES(ctx, !column_families.empty(),
errors::InvalidArgument("`column_families` is empty"));
float probability = 0;
OP_REQUIRES_OK(
ctx, ParseScalarArgument<float>(ctx, "probability", &probability));
OP_REQUIRES(
ctx, probability > 0 && probability <= 1,
errors::InvalidArgument(
"Probability outside the range of (0, 1]. Got: ", probability));
core::RefCountPtr<BigtableTableResource> resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
const uint64 num_outputs = columns.size() + 1;
std::vector<PartialTensorShape> output_shapes;
output_shapes.reserve(num_outputs);
DataTypeVector output_types;
output_types.reserve(num_outputs);
for (uint64 i = 0; i < num_outputs; ++i) {
output_shapes.push_back({});
output_types.push_back(DT_STRING);
}
*output = new Dataset(ctx, resource.get(), std::move(prefix),
std::move(start_key), std::move(end_key),
std::move(column_families), std::move(columns),
probability, output_types, std::move(output_shapes));
}
private:
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
string prefix, string start_key, string end_key,
std::vector<tstring> column_families,
std::vector<tstring> columns, float probability,
const DataTypeVector& output_types,
std::vector<PartialTensorShape> output_shapes)
: DatasetBase(DatasetContext(ctx)),
table_(table),
prefix_(std::move(prefix)),
start_key_(std::move(start_key)),
end_key_(std::move(end_key)),
column_families_(std::move(column_families)),
columns_(std::move(columns)),
column_family_regex_(RegexFromStringSet(column_families_)),
column_regex_(RegexFromStringSet(columns_)),
probability_(probability),
output_types_(output_types),
output_shapes_(std::move(output_shapes)) {
table_->Ref();
}
~Dataset() override { table_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::BigtableScan")}));
}
const DataTypeVector& output_dtypes() const override {
return output_types_;
}
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
string DebugString() const override {
return "BigtableScanDatasetOp::Dataset";
}
BigtableTableResource* table() const { return table_; }
Status CheckExternalState() const override {
return errors::FailedPrecondition(DebugString(),
" depends on external state.");
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
return errors::Unimplemented(DebugString(),
" does not support serialization");
}
private:
class Iterator : public BigtableReaderDatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: BigtableReaderDatasetIterator<Dataset>(params) {}
::google::cloud::bigtable::RowRange MakeRowRange() override {
if (!dataset()->prefix_.empty()) {
DCHECK(dataset()->start_key_.empty());
return ::google::cloud::bigtable::RowRange::Prefix(
dataset()->prefix_);
} else {
DCHECK(!dataset()->start_key_.empty())
<< "Both prefix and start_key were empty!";
return ::google::cloud::bigtable::RowRange::Range(
dataset()->start_key_, dataset()->end_key_);
}
}
::google::cloud::bigtable::Filter MakeFilter() override {
// TODO(saeta): Investigate optimal ordering here.
return ::google::cloud::bigtable::Filter::Chain(
::google::cloud::bigtable::Filter::Latest(1),
::google::cloud::bigtable::Filter::FamilyRegex(
dataset()->column_family_regex_),
::google::cloud::bigtable::Filter::ColumnRegex(
dataset()->column_regex_),
dataset()->probability_ != 1.0
? ::google::cloud::bigtable::Filter::RowSample(
dataset()->probability_)
: ::google::cloud::bigtable::Filter::PassAllFilter());
}
Status ParseRow(IteratorContext* ctx,
const ::google::cloud::bigtable::Row& row,
std::vector<Tensor>* out_tensors) override {
out_tensors->reserve(dataset()->columns_.size() + 1);
Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {});
row_key_tensor.scalar<tstring>()() = tstring(row.row_key());
out_tensors->emplace_back(std::move(row_key_tensor));
if (row.cells().size() > 2 * dataset()->columns_.size()) {
LOG(WARNING) << "An excessive number of columns ("
<< row.cells().size()
<< ") were retrieved when reading row: "
<< row.row_key();
}
for (uint64 i = 0; i < dataset()->columns_.size(); ++i) {
Tensor col_tensor(ctx->allocator({}), DT_STRING, {});
bool found_column = false;
for (auto cell_itr = row.cells().begin();
!found_column && cell_itr != row.cells().end(); ++cell_itr) {
if (cell_itr->family_name() == dataset()->column_families_[i] &&
tstring(cell_itr->column_qualifier()) ==
dataset()->columns_[i]) {
col_tensor.scalar<tstring>()() = tstring(cell_itr->value());
found_column = true;
}
}
if (!found_column) {
return errors::InvalidArgument(
"Column ", dataset()->column_families_[i], ":",
dataset()->columns_[i], " not found in row: ", row.row_key());
}
out_tensors->emplace_back(std::move(col_tensor));
}
return Status::OK();
}
};
BigtableTableResource* table_;
const string prefix_;
const string start_key_;
const string end_key_;
const std::vector<tstring> column_families_;
const std::vector<tstring> columns_;
const string column_family_regex_;
const string column_regex_;
const float probability_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
};
};
REGISTER_KERNEL_BUILDER(Name("BigtableScanDataset").Device(DEVICE_CPU),
BigtableScanDatasetOp);
} // namespace
} // namespace data
} // namespace tensorflow
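
The scan filter above differs from the lookup filter only in the optional
RowSample stage, which subsamples rows server-side when probability < 1. A
minimal sketch of the same conditional chain, with hypothetical regex
arguments:

// Sketch only: builds the scan filter from MakeFilter() above, adding
// RowSample only when the caller requested subsampling.
#include <string>
#include "google/cloud/bigtable/filters.h"

namespace cbt = ::google::cloud::bigtable;

cbt::Filter MakeScanFilter(const std::string& family_regex,
                           const std::string& column_regex,
                           double probability) {
  return cbt::Filter::Chain(
      cbt::Filter::Latest(1), cbt::Filter::FamilyRegex(family_regex),
      cbt::Filter::ColumnRegex(column_regex),
      probability != 1.0 ? cbt::Filter::RowSample(probability)
                         : cbt::Filter::PassAllFilter());
}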

View File

@ -1,462 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h"
#include "external/com_github_googleapis_googleapis/google/bigtable/v2/data.pb.h"
#include "google/protobuf/wrappers.pb.h"
#include "re2/re2.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/util/ptr_util.h"
// #include "util/task/codes.pb.h"
namespace tensorflow {
namespace {
void UpdateRow(const ::google::bigtable::v2::Mutation& mut,
std::map<string, string>* row) {
if (mut.has_set_cell()) {
CHECK(mut.set_cell().timestamp_micros() >= -1)
<< "Timestamp_micros: " << mut.set_cell().timestamp_micros();
auto col =
strings::Printf("%s:%s", mut.set_cell().family_name().c_str(),
string(mut.set_cell().column_qualifier()).c_str());
(*row)[col] = string(mut.set_cell().value());
} else if (mut.has_delete_from_column()) {
auto col = strings::Printf(
"%s:%s", mut.delete_from_column().family_name().c_str(),
string(mut.delete_from_column().column_qualifier()).c_str());
row->erase(col);
} else if (mut.has_delete_from_family()) {
auto itr = row->lower_bound(mut.delete_from_family().family_name());
auto prefix =
strings::Printf("%s:", mut.delete_from_family().family_name().c_str());
while (itr != row->end() && itr->first.substr(0, prefix.size()) == prefix) {
row->erase(itr);
}
} else if (mut.has_delete_from_row()) {
row->clear();
} else {
LOG(ERROR) << "Unknown mutation: " << mut.ShortDebugString();
}
}
} // namespace
class SampleRowKeysResponse : public grpc::ClientReaderInterface<
google::bigtable::v2::SampleRowKeysResponse> {
public:
explicit SampleRowKeysResponse(BigtableTestClient* client)
: client_(client) {}
bool NextMessageSize(uint32_t* sz) override {
mutex_lock l(mu_);
mutex_lock l2(client_->mu_);
if (num_messages_sent_ * 2 < client_->table_.rows.size()) {
    *sz = 10000;  // A value high enough not to worry about.
return true;
}
return false;
}
bool Read(google::bigtable::v2::SampleRowKeysResponse* resp) override {
// Send every other key from the table.
mutex_lock l(mu_);
mutex_lock l2(client_->mu_);
*resp = google::bigtable::v2::SampleRowKeysResponse();
auto itr = client_->table_.rows.begin();
for (uint64 i = 0; i < 2 * num_messages_sent_; ++i) {
++itr;
if (itr == client_->table_.rows.end()) {
return false;
}
}
resp->set_row_key(itr->first);
resp->set_offset_bytes(100 * num_messages_sent_);
num_messages_sent_++;
return true;
}
grpc::Status Finish() override { return grpc::Status::OK; }
void WaitForInitialMetadata() override {} // Do nothing.
private:
mutex mu_;
int64 num_messages_sent_ GUARDED_BY(mu_) = 0;
BigtableTestClient* client_; // Not owned.
};
class ReadRowsResponse : public grpc::ClientReaderInterface<
google::bigtable::v2::ReadRowsResponse> {
public:
ReadRowsResponse(BigtableTestClient* client,
google::bigtable::v2::ReadRowsRequest const& request)
: client_(client), request_(request) {}
bool NextMessageSize(uint32_t* sz) override {
mutex_lock l(mu_);
if (sent_first_message_) {
return false;
}
    *sz = 10000000;  // A value high enough not to worry about.
return true;
}
bool Read(google::bigtable::v2::ReadRowsResponse* resp) override {
mutex_lock l(mu_);
if (sent_first_message_) {
return false;
}
sent_first_message_ = true;
RowFilter filter = MakeRowFilter();
mutex_lock l2(client_->mu_);
*resp = google::bigtable::v2::ReadRowsResponse();
// Send all contents in first response.
for (auto itr = client_->table_.rows.begin();
itr != client_->table_.rows.end(); ++itr) {
if (filter.AllowRow(itr->first)) {
::google::bigtable::v2::ReadRowsResponse_CellChunk* chunk = nullptr;
bool sent_first = false;
for (auto col_itr = itr->second.columns.begin();
col_itr != itr->second.columns.end(); ++col_itr) {
if (filter.AllowColumn(col_itr->first)) {
chunk = resp->add_chunks();
if (!sent_first) {
sent_first = true;
chunk->set_row_key(itr->first);
}
auto colon_idx = col_itr->first.find(":");
CHECK(colon_idx != string::npos)
<< "No ':' found in: " << col_itr->first;
chunk->mutable_family_name()->set_value(
string(col_itr->first, 0, colon_idx));
chunk->mutable_qualifier()->set_value(
string(col_itr->first, ++colon_idx));
if (!filter.strip_values) {
chunk->set_value(col_itr->second);
}
if (filter.only_one_column) {
break;
}
}
}
if (sent_first) {
// We are sending this row, so set the commit flag on the last chunk.
chunk->set_commit_row(true);
}
}
}
return true;
}
grpc::Status Finish() override { return grpc::Status::OK; }
void WaitForInitialMetadata() override {} // Do nothing.
private:
struct RowFilter {
std::set<string> row_set;
std::vector<std::pair<string, string>> row_ranges;
double row_sample = 0.0; // Note: currently ignored.
std::unique_ptr<RE2> col_filter;
bool strip_values = false;
bool only_one_column = false;
bool AllowRow(const string& row) {
if (row_set.find(row) != row_set.end()) {
return true;
}
for (const auto& range : row_ranges) {
if (range.first <= row && range.second > row) {
return true;
}
}
return false;
}
bool AllowColumn(const string& col) {
if (col_filter) {
return RE2::FullMatch(col, *col_filter);
} else {
return true;
}
}
};
RowFilter MakeRowFilter() {
RowFilter filter;
for (auto i = request_.rows().row_keys().begin();
i != request_.rows().row_keys().end(); ++i) {
filter.row_set.insert(string(*i));
}
for (auto i = request_.rows().row_ranges().begin();
i != request_.rows().row_ranges().end(); ++i) {
if (i->start_key_case() !=
google::bigtable::v2::RowRange::kStartKeyClosed ||
i->end_key_case() != google::bigtable::v2::RowRange::kEndKeyOpen) {
LOG(WARNING) << "Skipping row range that cannot be processed: "
<< i->ShortDebugString();
continue;
}
filter.row_ranges.emplace_back(std::make_pair(
string(i->start_key_closed()), string(i->end_key_open())));
}
if (request_.filter().has_chain()) {
string family_filter;
string qualifier_filter;
for (auto i = request_.filter().chain().filters().begin();
i != request_.filter().chain().filters().end(); ++i) {
switch (i->filter_case()) {
case google::bigtable::v2::RowFilter::kFamilyNameRegexFilter:
family_filter = i->family_name_regex_filter();
break;
case google::bigtable::v2::RowFilter::kColumnQualifierRegexFilter:
qualifier_filter = i->column_qualifier_regex_filter();
break;
case google::bigtable::v2::RowFilter::kCellsPerColumnLimitFilter:
if (i->cells_per_column_limit_filter() != 1) {
LOG(ERROR) << "Unexpected cells_per_column_limit_filter: "
<< i->cells_per_column_limit_filter();
}
break;
case google::bigtable::v2::RowFilter::kStripValueTransformer:
filter.strip_values = i->strip_value_transformer();
break;
case google::bigtable::v2::RowFilter::kRowSampleFilter:
LOG(INFO) << "Ignoring row sample directive.";
break;
case google::bigtable::v2::RowFilter::kPassAllFilter:
break;
case google::bigtable::v2::RowFilter::kCellsPerRowLimitFilter:
filter.only_one_column = true;
break;
default:
LOG(WARNING) << "Ignoring unknown filter type: "
<< i->ShortDebugString();
}
}
if (family_filter.empty() || qualifier_filter.empty()) {
LOG(WARNING) << "Missing regex!";
} else {
string regex = strings::Printf("%s:%s", family_filter.c_str(),
qualifier_filter.c_str());
filter.col_filter.reset(new RE2(regex));
}
} else {
LOG(WARNING) << "Read request did not have a filter chain specified: "
<< request_.filter().DebugString();
}
return filter;
}
mutex mu_;
bool sent_first_message_ GUARDED_BY(mu_) = false;
BigtableTestClient* client_; // Not owned.
const google::bigtable::v2::ReadRowsRequest request_;
};
class MutateRowsResponse : public grpc::ClientReaderInterface<
google::bigtable::v2::MutateRowsResponse> {
public:
explicit MutateRowsResponse(size_t num_successes)
: num_successes_(num_successes) {}
bool NextMessageSize(uint32_t* sz) override {
mutex_lock l(mu_);
if (sent_first_message_) {
return false;
}
    *sz = 10000000;  // A value high enough not to worry about.
return true;
}
bool Read(google::bigtable::v2::MutateRowsResponse* resp) override {
mutex_lock l(mu_);
if (sent_first_message_) {
return false;
}
sent_first_message_ = true;
*resp = google::bigtable::v2::MutateRowsResponse();
for (size_t i = 0; i < num_successes_; ++i) {
auto entry = resp->add_entries();
entry->set_index(i);
}
return true;
}
grpc::Status Finish() override { return grpc::Status::OK; }
void WaitForInitialMetadata() override {} // Do nothing.
private:
const size_t num_successes_;
mutex mu_;
bool sent_first_message_ = false;
};
grpc::Status BigtableTestClient::MutateRow(
grpc::ClientContext* context,
google::bigtable::v2::MutateRowRequest const& request,
google::bigtable::v2::MutateRowResponse* response) {
mutex_lock l(mu_);
auto* row = &table_.rows[string(request.row_key())];
for (int i = 0; i < request.mutations_size(); ++i) {
UpdateRow(request.mutations(i), &row->columns);
}
*response = google::bigtable::v2::MutateRowResponse();
return grpc::Status::OK;
}
grpc::Status BigtableTestClient::CheckAndMutateRow(
grpc::ClientContext* context,
google::bigtable::v2::CheckAndMutateRowRequest const& request,
google::bigtable::v2::CheckAndMutateRowResponse* response) {
return grpc::Status(grpc::StatusCode::UNIMPLEMENTED,
"CheckAndMutateRow not implemented.");
}
grpc::Status BigtableTestClient::ReadModifyWriteRow(
grpc::ClientContext* context,
google::bigtable::v2::ReadModifyWriteRowRequest const& request,
google::bigtable::v2::ReadModifyWriteRowResponse* response) {
return grpc::Status(grpc::StatusCode::UNIMPLEMENTED,
"ReadModifyWriteRow not implemented.");
}
std::unique_ptr<grpc::ClientAsyncResponseReaderInterface<
google::bigtable::v2::ReadModifyWriteRowResponse>>
BigtableTestClient::AsyncReadModifyWriteRow(
grpc::ClientContext* context,
google::bigtable::v2::ReadModifyWriteRowRequest const& request,
grpc::CompletionQueue* cq) {
LOG(WARNING) << "Call to AsyncReadModifyWriteRow:" << __func__
<< "(); this will likely cause a crash!";
return nullptr;
}
std::unique_ptr<
grpc::ClientReaderInterface<google::bigtable::v2::ReadRowsResponse>>
BigtableTestClient::ReadRows(
grpc::ClientContext* context,
google::bigtable::v2::ReadRowsRequest const& request) {
return MakeUnique<ReadRowsResponse>(this, request);
}
std::unique_ptr<
grpc::ClientReaderInterface<google::bigtable::v2::SampleRowKeysResponse>>
BigtableTestClient::SampleRowKeys(
grpc::ClientContext* context,
google::bigtable::v2::SampleRowKeysRequest const& request) {
return MakeUnique<SampleRowKeysResponse>(this);
}
std::unique_ptr<
grpc::ClientReaderInterface<google::bigtable::v2::MutateRowsResponse>>
BigtableTestClient::MutateRows(
grpc::ClientContext* context,
google::bigtable::v2::MutateRowsRequest const& request) {
mutex_lock l(mu_);
for (auto i = request.entries().begin(); i != request.entries().end(); ++i) {
auto* row = &table_.rows[string(i->row_key())];
for (auto mut = i->mutations().begin(); mut != i->mutations().end();
++mut) {
UpdateRow(*mut, &row->columns);
}
}
return MakeUnique<MutateRowsResponse>(request.entries_size());
}
std::unique_ptr<grpc::ClientAsyncResponseReaderInterface<
google::bigtable::v2::MutateRowResponse>>
BigtableTestClient::AsyncMutateRow(
grpc::ClientContext* context,
google::bigtable::v2::MutateRowRequest const& request,
grpc::CompletionQueue* cq) {
LOG(WARNING) << "Call to InMemoryDataClient::" << __func__
<< "(); this will likely cause a crash!";
return nullptr;
}
std::unique_ptr<::grpc::ClientAsyncReaderInterface<
::google::bigtable::v2::SampleRowKeysResponse>>
BigtableTestClient::AsyncSampleRowKeys(
::grpc::ClientContext* context,
const ::google::bigtable::v2::SampleRowKeysRequest& request,
::grpc::CompletionQueue* cq, void* tag) {
LOG(WARNING) << "Call to InMemoryDataClient::" << __func__
<< "(); this will likely cause a crash!";
return nullptr;
}
std::unique_ptr<::grpc::ClientAsyncReaderInterface<
::google::bigtable::v2::MutateRowsResponse>>
BigtableTestClient::AsyncMutateRows(
::grpc::ClientContext* context,
const ::google::bigtable::v2::MutateRowsRequest& request,
::grpc::CompletionQueue* cq, void* tag) {
LOG(WARNING) << "Call to InMemoryDataClient::" << __func__
<< "(); this will likely cause a crash!";
return nullptr;
}
std::unique_ptr<grpc::ClientAsyncResponseReaderInterface<
google::bigtable::v2::CheckAndMutateRowResponse>>
BigtableTestClient::AsyncCheckAndMutateRow(
grpc::ClientContext* context,
const google::bigtable::v2::CheckAndMutateRowRequest& request,
grpc::CompletionQueue* cq) {
LOG(WARNING) << "Call to InMemoryDataClient::" << __func__
<< "(); this will likely cause a crash!";
return nullptr;
}
std::unique_ptr<
grpc::ClientAsyncReaderInterface<google::bigtable::v2::ReadRowsResponse>>
BigtableTestClient::AsyncReadRows(
grpc::ClientContext* context,
const google::bigtable::v2::ReadRowsRequest& request,
grpc::CompletionQueue* cq, void* tag) {
LOG(WARNING) << "Call to InMemoryDataClient::" << __func__
<< "(); this will likely cause a crash!";
return nullptr;
}
std::unique_ptr<
grpc::ClientAsyncReaderInterface<google::bigtable::v2::MutateRowsResponse>>
BigtableTestClient::PrepareAsyncMutateRows(
grpc::ClientContext* context,
const google::bigtable::v2::MutateRowsRequest& request,
grpc::CompletionQueue* cq) {
LOG(WARNING) << "Call to InMemoryDataClient::" << __func__
<< "(); this will likely cause a crash!";
return nullptr;
}
std::unique_ptr<::grpc::ClientAsyncReaderInterface<
::google::bigtable::v2::ReadRowsResponse>>
BigtableTestClient::PrepareAsyncReadRows(
::grpc::ClientContext* context,
const ::google::bigtable::v2::ReadRowsRequest& request,
::grpc::CompletionQueue* cq) {
LOG(WARNING) << "Call to InMemoryDataClient::" << __func__
<< "(); this will likely cause a crash!";
return nullptr;
}
std::shared_ptr<grpc::Channel> BigtableTestClient::Channel() {
LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely "
"cause a crash!";
return nullptr;
}
} // namespace tensorflow
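
The fake above stores each row as a map keyed by "family:qualifier" strings,
which is all UpdateRow needs to emulate SetCell and the delete mutations. A
minimal sketch of the SetCell path, using only standard setters on the
google::bigtable::v2::Mutation proto ("cf", "col", and "hello" are
placeholders):

// Sketch only: builds a SetCell mutation and applies it the way the fake's
// UpdateRow() does.
#include <iostream>
#include <map>
#include <string>
#include "google/bigtable/v2/data.pb.h"

void SetCellExample() {
  google::bigtable::v2::Mutation mut;
  auto* set_cell = mut.mutable_set_cell();
  set_cell->set_family_name("cf");
  set_cell->set_column_qualifier("col");
  set_cell->set_value("hello");
  // The fake keys cells as "family:qualifier" within a per-row map.
  std::map<std::string, std::string> row;
  row[set_cell->family_name() + ":" +
      std::string(set_cell->column_qualifier())] =
      std::string(set_cell->value());
  std::cout << row["cf:col"] << "\n";  // Prints "hello".
}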

View File

@ -1,138 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_
#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_
#include "google/cloud/bigtable/data_client.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
class BigtableTestClient : public ::google::cloud::bigtable::DataClient {
public:
std::string const& project_id() const override { return project_id_; }
std::string const& instance_id() const override { return instance_id_; }
void reset() override {
mutex_lock l(mu_);
table_ = Table();
}
grpc::Status MutateRow(
grpc::ClientContext* context,
google::bigtable::v2::MutateRowRequest const& request,
google::bigtable::v2::MutateRowResponse* response) override;
grpc::Status CheckAndMutateRow(
grpc::ClientContext* context,
google::bigtable::v2::CheckAndMutateRowRequest const& request,
google::bigtable::v2::CheckAndMutateRowResponse* response) override;
grpc::Status ReadModifyWriteRow(
grpc::ClientContext* context,
google::bigtable::v2::ReadModifyWriteRowRequest const& request,
google::bigtable::v2::ReadModifyWriteRowResponse* response) override;
std::unique_ptr<grpc::ClientAsyncResponseReaderInterface<
google::bigtable::v2::ReadModifyWriteRowResponse>>
AsyncReadModifyWriteRow(
grpc::ClientContext* context,
google::bigtable::v2::ReadModifyWriteRowRequest const& request,
grpc::CompletionQueue* cq) override;
std::unique_ptr<
grpc::ClientReaderInterface<google::bigtable::v2::ReadRowsResponse>>
ReadRows(grpc::ClientContext* context,
google::bigtable::v2::ReadRowsRequest const& request) override;
std::unique_ptr<
grpc::ClientReaderInterface<google::bigtable::v2::SampleRowKeysResponse>>
SampleRowKeys(
grpc::ClientContext* context,
google::bigtable::v2::SampleRowKeysRequest const& request) override;
std::unique_ptr<
grpc::ClientReaderInterface<google::bigtable::v2::MutateRowsResponse>>
MutateRows(grpc::ClientContext* context,
google::bigtable::v2::MutateRowsRequest const& request) override;
std::unique_ptr<grpc::ClientAsyncResponseReaderInterface<
google::bigtable::v2::MutateRowResponse>>
AsyncMutateRow(grpc::ClientContext* context,
google::bigtable::v2::MutateRowRequest const& request,
grpc::CompletionQueue* cq) override;
std::unique_ptr<::grpc::ClientAsyncReaderInterface<
::google::bigtable::v2::SampleRowKeysResponse>>
AsyncSampleRowKeys(
::grpc::ClientContext* context,
const ::google::bigtable::v2::SampleRowKeysRequest& request,
::grpc::CompletionQueue* cq, void* tag) override;
std::unique_ptr<::grpc::ClientAsyncReaderInterface<
::google::bigtable::v2::MutateRowsResponse>>
AsyncMutateRows(::grpc::ClientContext* context,
const ::google::bigtable::v2::MutateRowsRequest& request,
::grpc::CompletionQueue* cq, void* tag) override;
std::unique_ptr<grpc::ClientAsyncResponseReaderInterface<
google::bigtable::v2::CheckAndMutateRowResponse>>
AsyncCheckAndMutateRow(
grpc::ClientContext* context,
const google::bigtable::v2::CheckAndMutateRowRequest& request,
grpc::CompletionQueue* cq) override;
std::unique_ptr<
grpc::ClientAsyncReaderInterface<google::bigtable::v2::ReadRowsResponse>>
AsyncReadRows(grpc::ClientContext* context,
const google::bigtable::v2::ReadRowsRequest& request,
grpc::CompletionQueue* cq, void* tag) override;
std::unique_ptr<grpc::ClientAsyncReaderInterface<
google::bigtable::v2::MutateRowsResponse>>
PrepareAsyncMutateRows(grpc::ClientContext* context,
const google::bigtable::v2::MutateRowsRequest& request,
grpc::CompletionQueue* cq) override;
virtual std::unique_ptr<::grpc::ClientAsyncReaderInterface<
::google::bigtable::v2::ReadRowsResponse>>
PrepareAsyncReadRows(::grpc::ClientContext* context,
const ::google::bigtable::v2::ReadRowsRequest& request,
::grpc::CompletionQueue* cq) override;
std::shared_ptr<grpc::Channel> Channel() override;
private:
friend class SampleRowKeysResponse;
friend class ReadRowsResponse;
friend class MutateRowsResponse;
struct Row {
string row_key;
std::map<string, string> columns;
};
struct Table {
std::map<string, Row> rows;
};
mutex mu_;
const std::string project_id_ = "testproject";
const std::string instance_id_ = "testinstance";
Table table_ GUARDED_BY(mu_);
};
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_


@ -1,78 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace tensorflow {
namespace {
class BigtableTestClientOp : public OpKernel {
public:
explicit BigtableTestClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
~BigtableTestClientOp() override {
if (cinfo_.resource_is_private_to_kernel()) {
if (!cinfo_.resource_manager()
->Delete<BigtableClientResource>(cinfo_.container(),
cinfo_.name())
.ok()) {
// Do nothing; the resource may have been deleted by session resets.
}
}
}
void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
if (!initialized_) {
ResourceMgr* mgr = ctx->resource_manager();
OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
BigtableClientResource* resource;
OP_REQUIRES_OK(
ctx,
mgr->LookupOrCreate<BigtableClientResource>(
cinfo_.container(), cinfo_.name(), &resource,
[this, ctx](BigtableClientResource** ret)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
std::shared_ptr<google::cloud::bigtable::DataClient> client(
new BigtableTestClient());
// Note: must make explicit copies to sequence
// them before the move of client.
string project_id = client->project_id();
string instance_id = client->instance_id();
*ret = new BigtableClientResource(std::move(project_id),
std::move(instance_id),
std::move(client));
return Status::OK();
}));
initialized_ = true;
}
OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
ctx, 0, cinfo_.container(), cinfo_.name(),
MakeTypeIndex<BigtableClientResource>()));
}
private:
mutex mu_;
ContainerInfo cinfo_ GUARDED_BY(mu_);
bool initialized_ GUARDED_BY(mu_) = false;
};
REGISTER_KERNEL_BUILDER(Name("BigtableTestClient").Device(DEVICE_CPU),
BigtableTestClientOp);
} // namespace
} // namespace tensorflow


@ -1,349 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h"
#include "google/cloud/bigtable/table.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
void WriteCell(const string& row, const string& family, const string& column,
const string& value, ::google::cloud::bigtable::Table* table) {
::google::cloud::bigtable::SingleRowMutation mut(row);
mut.emplace_back(::google::cloud::bigtable::SetCell(family, column, value));
table->Apply(std::move(mut));
}
TEST(BigtableTestClientTest, EmptyRowRead) {
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
std::make_shared<BigtableTestClient>();
::google::cloud::bigtable::Table table(client_ptr, "test_table");
::google::cloud::bigtable::RowSet rowset;
rowset.Append("r1");
auto filter = ::google::cloud::bigtable::Filter::Chain(
::google::cloud::bigtable::Filter::Latest(1));
auto rows = table.ReadRows(std::move(rowset), filter);
EXPECT_EQ(rows.begin(), rows.end()) << "Some rows were returned in response!";
}
TEST(BigtableTestClientTest, SingleRowWriteAndRead) {
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
std::make_shared<BigtableTestClient>();
::google::cloud::bigtable::Table table(client_ptr, "test_table");
WriteCell("r1", "f1", "c1", "v1", &table);
::google::cloud::bigtable::RowSet rowset("r1");
auto filter = ::google::cloud::bigtable::Filter::Chain(
::google::cloud::bigtable::Filter::Latest(1));
auto rows = table.ReadRows(std::move(rowset), filter);
auto itr = rows.begin();
EXPECT_NE(itr, rows.end()) << "No rows were returned in response!";
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
EXPECT_EQ((*itr)->row_key(), "r1");
EXPECT_EQ((*itr)->cells().size(), 1);
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
EXPECT_EQ((*itr)->cells()[0].value(), "v1");
++itr;
EXPECT_EQ(itr, rows.end());
}
TEST(BigtableTestClientTest, MultiRowWriteAndSingleRowRead) {
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
std::make_shared<BigtableTestClient>();
::google::cloud::bigtable::Table table(client_ptr, "test_table");
WriteCell("r1", "f1", "c1", "v1", &table);
WriteCell("r2", "f1", "c1", "v2", &table);
WriteCell("r3", "f1", "c1", "v3", &table);
::google::cloud::bigtable::RowSet rowset("r1");
auto filter = ::google::cloud::bigtable::Filter::Chain(
::google::cloud::bigtable::Filter::Latest(1));
auto rows = table.ReadRows(std::move(rowset), filter);
auto itr = rows.begin();
EXPECT_NE(itr, rows.end()) << "Missing rows";
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
EXPECT_EQ((*itr)->row_key(), "r1");
EXPECT_EQ((*itr)->cells().size(), 1);
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
EXPECT_EQ((*itr)->cells()[0].value(), "v1");
++itr;
EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
}
TEST(BigtableTestClientTest, MultiRowWriteAndRead) {
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
std::make_shared<BigtableTestClient>();
::google::cloud::bigtable::Table table(client_ptr, "test_table");
WriteCell("r1", "f1", "c1", "v1", &table);
WriteCell("r2", "f1", "c1", "v2", &table);
WriteCell("r3", "f1", "c1", "v3", &table);
::google::cloud::bigtable::RowSet rowset("r1", "r2", "r3");
auto filter = ::google::cloud::bigtable::Filter::Chain(
::google::cloud::bigtable::Filter::Latest(1));
auto rows = table.ReadRows(std::move(rowset), filter);
auto itr = rows.begin();
EXPECT_NE(itr, rows.end()) << "Missing rows";
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
EXPECT_EQ((*itr)->row_key(), "r1");
EXPECT_EQ((*itr)->cells().size(), 1);
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
EXPECT_EQ((*itr)->cells()[0].value(), "v1");
++itr;
EXPECT_NE(itr, rows.end()) << "Missing rows";
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
EXPECT_EQ((*itr)->row_key(), "r2");
EXPECT_EQ((*itr)->cells().size(), 1);
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
EXPECT_EQ((*itr)->cells()[0].value(), "v2");
++itr;
EXPECT_NE(itr, rows.end()) << "Missing rows";
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
EXPECT_EQ((*itr)->row_key(), "r3");
EXPECT_EQ((*itr)->cells().size(), 1);
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
EXPECT_EQ((*itr)->cells()[0].value(), "v3");
++itr;
EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
}
TEST(BigtableTestClientTest, MultiRowWriteAndPrefixRead) {
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
std::make_shared<BigtableTestClient>();
::google::cloud::bigtable::Table table(client_ptr, "test_table");
WriteCell("r1", "f1", "c1", "v1", &table);
WriteCell("r2", "f1", "c1", "v2", &table);
WriteCell("r3", "f1", "c1", "v3", &table);
auto filter = ::google::cloud::bigtable::Filter::Chain(
::google::cloud::bigtable::Filter::Latest(1));
auto rows =
table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter);
auto itr = rows.begin();
EXPECT_NE(itr, rows.end()) << "Missing rows";
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
EXPECT_EQ((*itr)->row_key(), "r1");
EXPECT_EQ((*itr)->cells().size(), 1);
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
EXPECT_EQ((*itr)->cells()[0].value(), "v1");
++itr;
EXPECT_NE(itr, rows.end()) << "Missing rows";
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
EXPECT_EQ((*itr)->row_key(), "r2");
EXPECT_EQ((*itr)->cells().size(), 1);
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
EXPECT_EQ((*itr)->cells()[0].value(), "v2");
++itr;
EXPECT_NE(itr, rows.end()) << "Missing rows";
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
EXPECT_EQ((*itr)->row_key(), "r3");
EXPECT_EQ((*itr)->cells().size(), 1);
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
EXPECT_EQ((*itr)->cells()[0].value(), "v3");
++itr;
EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
}
TEST(BigtableTestClientTest, ColumnFiltering) {
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
std::make_shared<BigtableTestClient>();
::google::cloud::bigtable::Table table(client_ptr, "test_table");
WriteCell("r1", "f1", "c1", "v1", &table);
WriteCell("r2", "f1", "c1", "v2", &table);
WriteCell("r3", "f1", "c1", "v3", &table);
// Extra cells
WriteCell("r1", "f2", "c1", "v1", &table);
WriteCell("r2", "f2", "c1", "v2", &table);
WriteCell("r3", "f1", "c2", "v3", &table);
auto filter = ::google::cloud::bigtable::Filter::Chain(
::google::cloud::bigtable::Filter::Latest(1),
::google::cloud::bigtable::Filter::FamilyRegex("f1"),
::google::cloud::bigtable::Filter::ColumnRegex("c1"));
auto rows =
table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter);
auto itr = rows.begin();
EXPECT_NE(itr, rows.end()) << "Missing rows";
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
EXPECT_EQ((*itr)->row_key(), "r1");
EXPECT_EQ((*itr)->cells().size(), 1);
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
EXPECT_EQ((*itr)->cells()[0].value(), "v1");
++itr;
EXPECT_NE(itr, rows.end()) << "Missing rows";
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
EXPECT_EQ((*itr)->row_key(), "r2");
EXPECT_EQ((*itr)->cells().size(), 1);
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
EXPECT_EQ((*itr)->cells()[0].value(), "v2");
++itr;
EXPECT_NE(itr, rows.end()) << "Missing rows";
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
EXPECT_EQ((*itr)->row_key(), "r3");
EXPECT_EQ((*itr)->cells().size(), 1);
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
EXPECT_EQ((*itr)->cells()[0].value(), "v3");
++itr;
EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
}
TEST(BigtableTestClientTest, RowKeys) {
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
std::make_shared<BigtableTestClient>();
::google::cloud::bigtable::Table table(client_ptr, "test_table");
WriteCell("r1", "f1", "c1", "v1", &table);
WriteCell("r2", "f1", "c1", "v2", &table);
WriteCell("r3", "f1", "c1", "v3", &table);
// Extra cells
WriteCell("r1", "f2", "c1", "v1", &table);
WriteCell("r2", "f2", "c1", "v2", &table);
WriteCell("r3", "f1", "c2", "v3", &table);
auto filter = ::google::cloud::bigtable::Filter::Chain(
::google::cloud::bigtable::Filter::Latest(1),
::google::cloud::bigtable::Filter::CellsRowLimit(1),
::google::cloud::bigtable::Filter::StripValueTransformer());
auto rows =
table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter);
auto itr = rows.begin();
EXPECT_NE(itr, rows.end()) << "Missing rows";
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
EXPECT_EQ((*itr)->row_key(), "r1");
EXPECT_EQ((*itr)->cells().size(), 1);
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
EXPECT_EQ((*itr)->cells()[0].value(), "");
++itr;
EXPECT_NE(itr, rows.end()) << "Missing rows";
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
EXPECT_EQ((*itr)->row_key(), "r2");
EXPECT_EQ((*itr)->cells().size(), 1);
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
EXPECT_EQ((*itr)->cells()[0].value(), "");
++itr;
EXPECT_NE(itr, rows.end()) << "Missing rows";
EXPECT_TRUE(*itr) << "Error reading row: " << itr->status().message();
EXPECT_EQ((*itr)->row_key(), "r3");
EXPECT_EQ((*itr)->cells().size(), 1);
EXPECT_EQ((*itr)->cells()[0].family_name(), "f1");
EXPECT_EQ((*itr)->cells()[0].column_qualifier(), "c1");
EXPECT_EQ((*itr)->cells()[0].value(), "");
++itr;
EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
}
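// The next three tests pin down the fake's SampleRowKeys behavior: it appears
// to return every other row key, with offset_bytes spaced 100 apart.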
TEST(BigtableTestClientTest, SampleKeys) {
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
std::make_shared<BigtableTestClient>();
::google::cloud::bigtable::Table table(client_ptr, "test_table");
WriteCell("r1", "f1", "c1", "v1", &table);
WriteCell("r2", "f1", "c1", "v2", &table);
WriteCell("r3", "f1", "c1", "v3", &table);
WriteCell("r4", "f1", "c1", "v4", &table);
WriteCell("r5", "f1", "c1", "v5", &table);
auto resp = table.SampleRows();
EXPECT_TRUE(resp.ok());
EXPECT_EQ(3, resp->size());
EXPECT_EQ("r1", string((*resp)[0].row_key));
EXPECT_EQ(0, (*resp)[0].offset_bytes);
EXPECT_EQ("r3", string((*resp)[1].row_key));
EXPECT_EQ(100, (*resp)[1].offset_bytes);
EXPECT_EQ("r5", string((*resp)[2].row_key));
EXPECT_EQ(200, (*resp)[2].offset_bytes);
}
TEST(BigtableTestClientTest, SampleKeysShort) {
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
std::make_shared<BigtableTestClient>();
::google::cloud::bigtable::Table table(client_ptr, "test_table");
WriteCell("r1", "f1", "c1", "v1", &table);
auto resp = table.SampleRows();
EXPECT_TRUE(resp.ok());
EXPECT_EQ(1, resp->size());
EXPECT_EQ("r1", string((*resp)[0].row_key));
}
TEST(BigtableTestClientTest, SampleKeysEvenNumber) {
std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
std::make_shared<BigtableTestClient>();
::google::cloud::bigtable::Table table(client_ptr, "test_table");
WriteCell("r1", "f1", "c1", "v1", &table);
WriteCell("r2", "f1", "c1", "v2", &table);
WriteCell("r3", "f1", "c1", "v3", &table);
WriteCell("r4", "f1", "c1", "v4", &table);
auto resp = table.SampleRows();
EXPECT_TRUE(resp.ok());
EXPECT_EQ(2, resp->size());
EXPECT_EQ("r1", string((*resp)[0].row_key));
EXPECT_EQ("r3", string((*resp)[1].row_key));
}
} // namespace
} // namespace tensorflow


@ -1,107 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
// TODO(saeta): Add support for setting ClientOptions values.
REGISTER_OP("BigtableClient")
.Attr("project_id: string")
.Attr("instance_id: string")
.Attr("connection_pool_size: int")
.Attr("max_receive_message_size: int = -1")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Output("client: resource")
.SetShapeFn(shape_inference::ScalarShape);
// TODO(saeta): Add support for Application Profiles.
// See https://cloud.google.com/bigtable/docs/app-profiles for more info.
REGISTER_OP("BigtableTable")
.Input("client: resource")
.Attr("table_name: string")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Output("table: resource")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("DatasetToBigtable")
.Input("table: resource")
.Input("input_dataset: variant")
.Input("column_families: string")
.Input("columns: string")
.Input("timestamp: int64")
.SetShapeFn(shape_inference::NoOutputs);
REGISTER_OP("BigtableLookupDataset")
.Input("keys_dataset: variant")
.Input("table: resource")
.Input("column_families: string")
.Input("columns: string")
.Output("handle: variant")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("BigtablePrefixKeyDataset")
.Input("table: resource")
.Input("prefix: string")
.Output("handle: variant")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("BigtableRangeKeyDataset")
.Input("table: resource")
.Input("start_key: string")
.Input("end_key: string")
.Output("handle: variant")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("BigtableSampleKeysDataset")
.Input("table: resource")
.Output("handle: variant")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("BigtableSampleKeyPairsDataset")
.Input("table: resource")
.Input("prefix: string")
.Input("start_key: string")
.Input("end_key: string")
.Output("handle: variant")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
// TODO(saeta): Support continuing despite bad data (e.g. an empty string),
// or skipping incomplete rows.
REGISTER_OP("BigtableScanDataset")
.Input("table: resource")
.Input("prefix: string")
.Input("start_key: string")
.Input("end_key: string")
.Input("column_families: string")
.Input("columns: string")
.Input("probability: float")
.Output("handle: variant")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
} // namespace tensorflow


@ -1,27 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
REGISTER_OP("BigtableTestClient")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Output("client: resource")
.SetShapeFn(shape_inference::ScalarShape);
} // namespace tensorflow


@ -1,20 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This module contains tests for the bigtable integration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


@ -1,272 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Bigtable Ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib import bigtable
from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
from tensorflow.contrib.bigtable.ops import gen_bigtable_test_ops
from tensorflow.contrib.bigtable.python.ops import bigtable_api
from tensorflow.contrib.util import loader
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.util import compat
_bigtable_so = loader.load_op_library(
resource_loader.get_path_to_datafile("_bigtable_test.so"))
def _ListOfTuplesOfStringsToBytes(values):
return [(compat.as_bytes(i[0]), compat.as_bytes(i[1])) for i in values]
class BigtableOpsTest(test.TestCase):
COMMON_ROW_KEYS = ["r1", "r2", "r3"]
COMMON_VALUES = ["v1", "v2", "v3"]
def setUp(self):
self._client = gen_bigtable_test_ops.bigtable_test_client()
table = gen_bigtable_ops.bigtable_table(self._client, "testtable")
self._table = bigtable.BigtableTable("testtable", None, table)
def _makeSimpleDataset(self):
output_rows = dataset_ops.Dataset.from_tensor_slices(self.COMMON_ROW_KEYS)
output_values = dataset_ops.Dataset.from_tensor_slices(self.COMMON_VALUES)
return dataset_ops.Dataset.zip((output_rows, output_values))
def _writeCommonValues(self, sess):
output_ds = self._makeSimpleDataset()
write_op = self._table.write(output_ds, ["cf1"], ["c1"])
sess.run(write_op)
def runReadKeyTest(self, read_ds):
itr = dataset_ops.make_initializable_iterator(read_ds)
n = itr.get_next()
expected = list(self.COMMON_ROW_KEYS)
expected.reverse()
with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
for i in range(3):
output = sess.run(n)
want = expected.pop()
self.assertEqual(
compat.as_bytes(want), compat.as_bytes(output),
"Unequal at step %d: want: %s, got: %s" % (i, want, output))
def testReadPrefixKeys(self):
self.runReadKeyTest(self._table.keys_by_prefix_dataset("r"))
def testReadRangeKeys(self):
self.runReadKeyTest(self._table.keys_by_range_dataset("r1", "r4"))
def runScanTest(self, read_ds):
itr = dataset_ops.make_initializable_iterator(read_ds)
n = itr.get_next()
expected_keys = list(self.COMMON_ROW_KEYS)
expected_keys.reverse()
expected_values = list(self.COMMON_VALUES)
expected_values.reverse()
with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
for i in range(3):
output = sess.run(n)
want = expected_keys.pop()
self.assertEqual(
compat.as_bytes(want), compat.as_bytes(output[0]),
"Unequal keys at step %d: want: %s, got: %s" % (i, want, output[0]))
want = expected_values.pop()
self.assertEqual(
compat.as_bytes(want), compat.as_bytes(output[1]),
"Unequal values at step: %d: want: %s, got: %s" % (i, want,
output[1]))
def testScanPrefixStringCol(self):
self.runScanTest(self._table.scan_prefix("r", cf1="c1"))
def testScanPrefixListCol(self):
self.runScanTest(self._table.scan_prefix("r", cf1=["c1"]))
def testScanPrefixTupleCol(self):
self.runScanTest(self._table.scan_prefix("r", columns=("cf1", "c1")))
def testScanRangeStringCol(self):
self.runScanTest(self._table.scan_range("r1", "r4", cf1="c1"))
def testScanRangeListCol(self):
self.runScanTest(self._table.scan_range("r1", "r4", cf1=["c1"]))
def testScanRangeTupleCol(self):
self.runScanTest(self._table.scan_range("r1", "r4", columns=("cf1", "c1")))
def testLookup(self):
ds = self._table.keys_by_prefix_dataset("r")
ds = ds.apply(self._table.lookup_columns(cf1="c1"))
itr = dataset_ops.make_initializable_iterator(ds)
n = itr.get_next()
expected_keys = list(self.COMMON_ROW_KEYS)
expected_values = list(self.COMMON_VALUES)
expected_tuples = zip(expected_keys, expected_values)
with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
for i, elem in enumerate(expected_tuples):
output = sess.run(n)
self.assertEqual(
compat.as_bytes(elem[0]), compat.as_bytes(output[0]),
"Unequal keys at step %d: want: %s, got: %s" %
(i, compat.as_bytes(elem[0]), compat.as_bytes(output[0])))
self.assertEqual(
compat.as_bytes(elem[1]), compat.as_bytes(output[1]),
"Unequal values at step %d: want: %s, got: %s" %
(i, compat.as_bytes(elem[1]), compat.as_bytes(output[1])))
def testSampleKeys(self):
ds = self._table.sample_keys()
itr = dataset_ops.make_initializable_iterator(ds)
n = itr.get_next()
expected_key = self.COMMON_ROW_KEYS[0]
with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
output = sess.run(n)
self.assertEqual(
compat.as_bytes(self.COMMON_ROW_KEYS[0]), compat.as_bytes(output),
"Unequal keys: want: %s, got: %s" % (compat.as_bytes(
self.COMMON_ROW_KEYS[0]), compat.as_bytes(output)))
output = sess.run(n)
self.assertEqual(
compat.as_bytes(self.COMMON_ROW_KEYS[2]), compat.as_bytes(output),
"Unequal keys: want: %s, got: %s" % (compat.as_bytes(
self.COMMON_ROW_KEYS[2]), compat.as_bytes(output)))
with self.assertRaises(errors.OutOfRangeError):
sess.run(n)
def runSampleKeyPairsTest(self, ds, expected_key_pairs):
itr = dataset_ops.make_initializable_iterator(ds)
n = itr.get_next()
with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
for i, elems in enumerate(expected_key_pairs):
output = sess.run(n)
self.assertEqual(
compat.as_bytes(elems[0]), compat.as_bytes(output[0]),
"Unequal key pair (first element) at step %d; want: %s, got %s" %
(i, compat.as_bytes(elems[0]), compat.as_bytes(output[0])))
self.assertEqual(
compat.as_bytes(elems[1]), compat.as_bytes(output[1]),
"Unequal key pair (second element) at step %d; want: %s, got %s" %
(i, compat.as_bytes(elems[1]), compat.as_bytes(output[1])))
with self.assertRaises(errors.OutOfRangeError):
sess.run(n)
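# The fake client samples the keys "r1" and "r3" (see the C++ SampleKeys test
# above), so a prefix scan over "r" is split into the ranges ["r", "r1"),
# ["r1", "r3"), and ["r3", "s"), where "s" is the first string past the "r"
# prefix.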
def testSampleKeyPairsSimplePrefix(self):
ds = bigtable_api._BigtableSampleKeyPairsDataset(
self._table, prefix="r", start="", end="")
expected_key_pairs = [("r", "r1"), ("r1", "r3"), ("r3", "s")]
self.runSampleKeyPairsTest(ds, expected_key_pairs)
def testSampleKeyPairsSimpleRange(self):
ds = bigtable_api._BigtableSampleKeyPairsDataset(
self._table, prefix="", start="r1", end="r3")
expected_key_pairs = [("r1", "r3")]
self.runSampleKeyPairsTest(ds, expected_key_pairs)
def testSampleKeyPairsSkipRangePrefix(self):
ds = bigtable_api._BigtableSampleKeyPairsDataset(
self._table, prefix="r2", start="", end="")
expected_key_pairs = [("r2", "r3")]
self.runSampleKeyPairsTest(ds, expected_key_pairs)
def testSampleKeyPairsSkipRangeRange(self):
ds = bigtable_api._BigtableSampleKeyPairsDataset(
self._table, prefix="", start="r2", end="r3")
expected_key_pairs = [("r2", "r3")]
self.runSampleKeyPairsTest(ds, expected_key_pairs)
def testSampleKeyPairsOffsetRanges(self):
ds = bigtable_api._BigtableSampleKeyPairsDataset(
self._table, prefix="", start="r2", end="r4")
expected_key_pairs = [("r2", "r3"), ("r3", "r4")]
self.runSampleKeyPairsTest(ds, expected_key_pairs)
def testSampleKeyPairEverything(self):
ds = bigtable_api._BigtableSampleKeyPairsDataset(
self._table, prefix="", start="", end="")
expected_key_pairs = [("", "r1"), ("r1", "r3"), ("r3", "")]
self.runSampleKeyPairsTest(ds, expected_key_pairs)
def testSampleKeyPairsPrefixAndStartKey(self):
ds = bigtable_api._BigtableSampleKeyPairsDataset(
self._table, prefix="r", start="r1", end="")
itr = dataset_ops.make_initializable_iterator(ds)
with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(itr.initializer)
def testSampleKeyPairsPrefixAndEndKey(self):
ds = bigtable_api._BigtableSampleKeyPairsDataset(
self._table, prefix="r", start="", end="r3")
itr = dataset_ops.make_initializable_iterator(ds)
with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(itr.initializer)
def testParallelScanPrefix(self):
ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1")
itr = dataset_ops.make_initializable_iterator(ds)
n = itr.get_next()
with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
actual_values = []
for _ in range(len(expected_values)):
output = sess.run(n)
actual_values.append(output)
with self.assertRaises(errors.OutOfRangeError):
sess.run(n)
self.assertItemsEqual(
_ListOfTuplesOfStringsToBytes(expected_values),
_ListOfTuplesOfStringsToBytes(actual_values))
def testParallelScanRange(self):
ds = self._table.parallel_scan_range(start="r1", end="r4", cf1="c1")
itr = dataset_ops.make_initializable_iterator(ds)
n = itr.get_next()
with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
actual_values = []
for _ in range(len(expected_values)):
output = sess.run(n)
actual_values.append(output)
with self.assertRaises(errors.OutOfRangeError):
sess.run(n)
self.assertItemsEqual(
_ListOfTuplesOfStringsToBytes(expected_values),
_ListOfTuplesOfStringsToBytes(actual_values))
if __name__ == "__main__":
test.main()


@ -1,20 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This module contains the Python API for the Cloud Bigtable integration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


@ -1,708 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Python API for TensorFlow's Cloud Bigtable integration.
TensorFlow has support for reading from and writing to Cloud Bigtable. To use
TensorFlow + Cloud Bigtable integration, first create a BigtableClient to
configure your connection to Cloud Bigtable, and then create a BigtableTable
object to allow you to create numerous `tf.data.Dataset`s to read data, or
write a `tf.data.Dataset` object to the underlying Cloud Bigtable table.
For background on Cloud Bigtable, see: https://cloud.google.com/bigtable .
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six import iteritems
from six import string_types
from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
from tensorflow.contrib.util import loader
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.platform import resource_loader
_bigtable_so = loader.load_op_library(
resource_loader.get_path_to_datafile("_bigtable.so"))
class BigtableClient(object):
"""BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF.
BigtableClient encapsulates a connection to Cloud Bigtable, and exposes the
`table` method to open a Bigtable table.
"""
def __init__(self,
project_id,
instance_id,
connection_pool_size=None,
max_receive_message_size=None):
"""Creates a BigtableClient that can be used to open connections to tables.
Args:
project_id: A string representing the GCP project id to connect to.
instance_id: A string representing the Bigtable instance to connect to.
connection_pool_size: (Optional.) A number representing the number of
concurrent connections to the Cloud Bigtable service to make.
max_receive_message_size: (Optional.) The maximum bytes received in a
single gRPC response.
Raises:
ValueError: if the arguments are invalid (e.g. wrong type, or outside the
expected range, such as a negative value).
"""
if not isinstance(project_id, str):
raise ValueError("`project_id` must be a string")
self._project_id = project_id
if not isinstance(instance_id, str):
raise ValueError("`instance_id` must be a string")
self._instance_id = instance_id
if connection_pool_size is None:
connection_pool_size = -1
elif connection_pool_size < 1:
raise ValueError("`connection_pool_size` must be positive")
if max_receive_message_size is None:
max_receive_message_size = -1
elif max_receive_message_size < 1:
raise ValueError("`max_receive_message_size` must be positive")
self._connection_pool_size = connection_pool_size
self._resource = gen_bigtable_ops.bigtable_client(
project_id, instance_id, connection_pool_size, max_receive_message_size)
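# Illustrative sketch: BigtableClient("my-project", "my-instance",
# connection_pool_size=4) would request four concurrent connections. When an
# optional argument is left as None, -1 is forwarded to the op, which
# presumably selects the kernel's built-in default (the BigtableClient op
# declares `max_receive_message_size: int = -1`).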
def table(self, name, snapshot=None):
"""Opens a table and returns a `tf.contrib.bigtable.BigtableTable` object.
Args:
name: A `tf.string` `tf.Tensor` name of the table to open.
snapshot: Either a `tf.string` `tf.Tensor` snapshot id, or `True` to
request the creation of a snapshot. (Note: currently unimplemented.)
Returns:
A `tf.contrib.bigtable.BigtableTable` Python object representing the
operations available on the table.
"""
# TODO(saeta): Implement snapshot functionality.
table = gen_bigtable_ops.bigtable_table(self._resource, name)
return BigtableTable(name, snapshot, table)
class BigtableTable(object):
"""Entry point for reading and writing data in Cloud Bigtable.
This BigtableTable class is the Python representation of the Cloud Bigtable
table within TensorFlow. Methods on this class allow data to be read from and
written to the Cloud Bigtable service in a flexible and high-performance
manner.
"""
# TODO(saeta): Investigate implementing tf.contrib.lookup.LookupInterface.
# TODO(saeta): Consider variant tensors instead of resources (while supporting
# connection pooling).
def __init__(self, name, snapshot, resource):
self._name = name
self._snapshot = snapshot
self._resource = resource
def lookup_columns(self, *args, **kwargs):
"""Retrieves the values of columns for a dataset of keys.
Example usage:
```python
table = bigtable_client.table("my_table")
key_dataset = table.keys_by_prefix_dataset("imagenet")
images = key_dataset.apply(table.lookup_columns(("cf1", "image"),
("cf2", "label"),
("cf2", "boundingbox")))
training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128)
```
Alternatively, you can use keyword arguments to specify the columns to
capture. Example (same as above, rewritten):
```python
table = bigtable_client.table("my_table")
key_dataset = table.keys_by_prefix_dataset("imagenet")
images = key_dataset.apply(table.lookup_columns(
cf1="image", cf2=("label", "boundingbox")))
training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128)
```
Note: certain `kwargs` keys are reserved, and thus, some column families
cannot be identified using the `kwargs` syntax. Instead, please use the
`args` syntax. This list includes:
- 'name'
Note: this list can change at any time.
Args:
*args: A list of tuples containing (column family, column name) pairs.
**kwargs: Column families (keys) and column qualifiers (values).
Returns:
A function that can be passed to `tf.data.Dataset.apply` to retrieve the
values of columns for the rows.
"""
table = self # Capture self
normalized = args
if normalized is None:
normalized = []
if isinstance(normalized, tuple):
normalized = list(normalized)
for key, value in iteritems(kwargs):
if key == "name":
continue
if isinstance(value, str):
normalized.append((key, value))
continue
for col in value:
normalized.append((key, col))
def _apply_fn(dataset):
# TODO(saeta): Verify dataset's types are correct!
return _BigtableLookupDataset(dataset, table, normalized)
return _apply_fn
def keys_by_range_dataset(self, start, end):
"""Retrieves all row keys between start and end.
Note: it does NOT retrieve the values of columns.
Args:
start: The start row key. The row keys for rows after start (inclusive)
will be retrieved.
end: (Optional.) The end row key. Rows up to (but not including) end will
be retrieved. If end is None, all subsequent row keys will be retrieved.
Returns:
A `tf.data.Dataset` containing `tf.string` Tensors corresponding to all
of the row keys between `start` and `end`.
"""
# TODO(saeta): Make inclusive / exclusive configurable?
if end is None:
end = ""
return _BigtableRangeKeyDataset(self, start, end)
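# For example, keys_by_range_dataset("r1", "r3") yields "r1" and "r2" (plus
# any other keys between them), but not "r3", since `start` is inclusive and
# `end` is exclusive.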
def keys_by_prefix_dataset(self, prefix):
"""Retrieves the row keys matching a given prefix.
Args:
prefix: All row keys that begin with `prefix` in the table will be
retrieved.
Returns:
A `tf.data.Dataset` containing `tf.string` Tensors corresponding to all
of the row keys matching that prefix.
"""
return _BigtablePrefixKeyDataset(self, prefix)
def sample_keys(self):
"""Retrieves a sampling of row keys from the Bigtable table.
This dataset is most often used in conjunction with
`tf.data.experimental.parallel_interleave` to construct a set of ranges for
scanning in parallel.
Returns:
A `tf.data.Dataset` returning string row keys.
"""
return _BigtableSampleKeysDataset(self)
def scan_prefix(self, prefix, probability=None, columns=None, **kwargs):
"""Retrieves row (including values) from the Bigtable service.
Rows with row-key prefixed by `prefix` will be retrieved.
Specifying the columns to retrieve for each row is done by either using
kwargs or in the columns parameter. To retrieve values of the columns "c1",
and "c2" from the column family "cfa", and the value of the column "c3"
from column family "cfb", the following datasets (`ds1`, and `ds2`) are
equivalent:
```
table = # ...
ds1 = table.scan_prefix("row_prefix", columns=[("cfa", "c1"),
("cfa", "c2"),
("cfb", "c3")])
ds2 = table.scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3")
```
Note: only the latest value of a cell will be retrieved.
Args:
prefix: The prefix all row keys must match to be retrieved for prefix-
based scans.
probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
A non-1 value indicates to probabilistically sample rows with the
provided probability.
columns: The columns to read. Note: most commonly, they are expressed as
kwargs. Use the columns value if you are using column families that are
reserved. The value of columns and kwargs are merged. Columns is a list
of tuples of strings ("column_family", "column_qualifier").
**kwargs: The column families and columns to read. Keys are treated as
column_families, and values can be either lists of strings, or strings
that are treated as the column qualifier (column name).
Returns:
A `tf.data.Dataset` returning the row keys and the cell contents.
Raises:
ValueError: If the configured probability is unexpected.
"""
probability = _normalize_probability(probability)
normalized = _normalize_columns(columns, kwargs)
return _BigtableScanDataset(self, prefix, "", "", normalized, probability)
def scan_range(self, start, end, probability=None, columns=None, **kwargs):
"""Retrieves rows (including values) from the Bigtable service.
Rows with row-keys between `start` and `end` will be retrieved.
Specifying the columns to retrieve for each row is done by either using
kwargs or in the columns parameter. To retrieve values of the columns "c1",
and "c2" from the column family "cfa", and the value of the column "c3"
from column family "cfb", the following datasets (`ds1`, and `ds2`) are
equivalent:
```
table = # ...
ds1 = table.scan_range("row_start", "row_end", columns=[("cfa", "c1"),
("cfa", "c2"),
("cfb", "c3")])
ds2 = table.scan_range("row_start", "row_end", cfa=["c1", "c2"], cfb="c3")
```
Note: only the latest value of a cell will be retrieved.
Args:
start: The start of the range when scanning by range.
end: (Optional.) The end of the range when scanning by range.
probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
A non-1 value indicates to probabilistically sample rows with the
provided probability.
columns: The columns to read. Note: most commonly, they are expressed as
kwargs. Use the columns value if you are using column families that are
reserved. The value of columns and kwargs are merged. Columns is a list
of tuples of strings ("column_family", "column_qualifier").
**kwargs: The column families and columns to read. Keys are treated as
column_families, and values can be either lists of strings, or strings
that are treated as the column qualifier (column name).
Returns:
A `tf.data.Dataset` returning the row keys and the cell contents.
Raises:
ValueError: If the configured probability is unexpected.
"""
probability = _normalize_probability(probability)
normalized = _normalize_columns(columns, kwargs)
return _BigtableScanDataset(self, "", start, end, normalized, probability)
def parallel_scan_prefix(self,
prefix,
num_parallel_scans=None,
probability=None,
columns=None,
**kwargs):
"""Retrieves row (including values) from the Bigtable service at high speed.
Rows with row-key prefixed by `prefix` will be retrieved. This method is
similar to `scan_prefix`, but by contrast performs multiple sub-scans in
parallel in order to achieve higher performance.
Note: The dataset produced by this method is not deterministic!
Specifying the columns to retrieve for each row is done by either using
kwargs or in the columns parameter. To retrieve values of the columns "c1",
and "c2" from the column family "cfa", and the value of the column "c3"
from column family "cfb", the following datasets (`ds1`, and `ds2`) are
equivalent:
```
table = # ...
ds1 = table.parallel_scan_prefix("row_prefix", columns=[("cfa", "c1"),
("cfa", "c2"),
("cfb", "c3")])
ds2 = table.parallel_scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3")
```
Note: only the latest value of a cell will be retrieved.
Args:
prefix: The prefix all row keys must match to be retrieved for prefix-
based scans.
num_parallel_scans: (Optional.) The number of concurrent scans against the
Cloud Bigtable instance.
probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
A non-1 value indicates to probabilistically sample rows with the
provided probability.
columns: The columns to read. Note: most commonly, they are expressed as
kwargs. Use the columns value if you are using column families that are
reserved. The value of columns and kwargs are merged. Columns is a list
of tuples of strings ("column_family", "column_qualifier").
**kwargs: The column families and columns to read. Keys are treated as
column_families, and values can be either lists of strings, or strings
that are treated as the column qualifier (column name).
Returns:
A `tf.data.Dataset` returning the row keys and the cell contents.
Raises:
ValueError: If the configured probability is unexpected.
"""
probability = _normalize_probability(probability)
normalized = _normalize_columns(columns, kwargs)
ds = _BigtableSampleKeyPairsDataset(self, prefix, "", "")
return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability,
normalized)
def parallel_scan_range(self,
start,
end,
num_parallel_scans=None,
probability=None,
columns=None,
**kwargs):
"""Retrieves rows (including values) from the Bigtable service.
Rows with row-keys between `start` and `end` will be retrieved. This method
is similar to `scan_range`, but by contrast performs multiple sub-scans in
parallel in order to achieve higher performance.
Note: The dataset produced by this method is not deterministic!
Specifying the columns to retrieve for each row is done by either using
kwargs or in the columns parameter. To retrieve values of the columns "c1",
and "c2" from the column family "cfa", and the value of the column "c3"
from column family "cfb", the following datasets (`ds1`, and `ds2`) are
equivalent:
```
table = # ...
ds1 = table.parallel_scan_range("row_start",
"row_end",
columns=[("cfa", "c1"),
("cfa", "c2"),
("cfb", "c3")])
ds2 = table.parallel_scan_range("row_start", "row_end",
cfa=["c1", "c2"], cfb="c3")
```
Note: only the latest value of a cell will be retrieved.
Args:
start: The start of the range when scanning by range.
end: (Optional.) The end of the range when scanning by range.
num_parallel_scans: (Optional.) The number of concurrent scans against the
Cloud Bigtable instance.
probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
A non-1 value indicates to probabilistically sample rows with the
provided probability.
columns: The columns to read. Note: most commonly, they are expressed as
kwargs. Use the columns value if you are using column families that are
reserved. The value of columns and kwargs are merged. Columns is a list
of tuples of strings ("column_family", "column_qualifier").
**kwargs: The column families and columns to read. Keys are treated as
column_families, and values can be either lists of strings, or strings
that are treated as the column qualifier (column name).
Returns:
A `tf.data.Dataset` returning the row keys and the cell contents.
Raises:
ValueError: If the configured probability is unexpected.
"""
probability = _normalize_probability(probability)
normalized = _normalize_columns(columns, kwargs)
ds = _BigtableSampleKeyPairsDataset(self, "", start, end)
return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability,
normalized)
def write(self, dataset, column_families, columns, timestamp=None):
"""Writes a dataset to the table.
Args:
dataset: A `tf.data.Dataset` to be written to this table. It must produce
a list of number-of-columns+1 elements, all of which must be strings.
The first value will be used as the row key, and subsequent values will
be used as cell values for the corresponding columns from the
corresponding column_families and columns entries.
column_families: A `tf.Tensor` of `tf.string`s corresponding to the
column family names to store the dataset's elements into.
columns: A `tf.Tensor` of `tf.string`s corresponding to the column names
to store the dataset's elements into.
timestamp: (Optional.) An int64 timestamp to write all the values at.
Leave as None to use server-provided timestamps.
Returns:
A `tf.Operation` that can be run to perform the write.
Raises:
ValueError: If there are unexpected or incompatible types, or if the
number of columns and column_families does not match the output of
`dataset`.
"""
if timestamp is None:
timestamp = -1 # Bigtable server provided timestamp.
for tensor_type in nest.flatten(
dataset_ops.get_legacy_output_types(dataset)):
if tensor_type != dtypes.string:
raise ValueError("Not all elements of the dataset were `tf.string`")
for shape in nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)):
if not shape.is_compatible_with(tensor_shape.TensorShape([])):
raise ValueError("Not all elements of the dataset were scalars")
if len(column_families) != len(columns):
raise ValueError("len(column_families) != len(columns)")
if len(nest.flatten(
dataset_ops.get_legacy_output_types(dataset))) != len(columns) + 1:
raise ValueError("A column name must be specified for every component of "
"the dataset elements. (e.g.: len(columns) != "
"len(dataset.output_types))")
return gen_bigtable_ops.dataset_to_bigtable(
self._resource,
dataset._variant_tensor, # pylint: disable=protected-access
column_families,
columns,
timestamp)
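# An illustrative write() call (all names and values here are placeholders):
# each dataset element is (row_key, value for cf1:c1, value for cf2:c2), so
# the dataset has len(columns) + 1 == 3 string components per element.
#
#   ds = tf.data.Dataset.from_tensor_slices(
#       (["r1", "r2"], ["v1", "v2"], ["w1", "w2"]))
#   write_op = table.write(ds, column_families=["cf1", "cf2"],
#                          columns=["c1", "c2"])
#   sess.run(write_op)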
def _make_parallel_scan_dataset(self, ds, num_parallel_scans,
normalized_probability, normalized_columns):
"""Builds a parallel dataset from a given range.
Args:
ds: A `_BigtableSampleKeyPairsDataset` returning ranges of keys to use.
num_parallel_scans: The number of concurrent parallel scans to use.
normalized_probability: A number between 0 and 1 for the keep probability.
normalized_columns: The column families and column qualifiers to retrieve.
Returns:
A `tf.data.Dataset` representing the result of the parallel scan.
"""
if num_parallel_scans is None:
num_parallel_scans = 50
ds = ds.shuffle(buffer_size=10000) # TODO(saeta): Make configurable.
def _interleave_fn(start, end):
return _BigtableScanDataset(
self,
prefix="",
start=start,
end=end,
normalized=normalized_columns,
probability=normalized_probability)
# Note prefetch_input_elements must be set in order to avoid rpc timeouts.
ds = ds.apply(
interleave_ops.parallel_interleave(
_interleave_fn,
cycle_length=num_parallel_scans,
sloppy=True,
prefetch_input_elements=1))
return ds
def _normalize_probability(probability):
if probability is None:
probability = 1.0
if isinstance(probability, float) and (probability <= 0.0 or
probability > 1.0):
raise ValueError("probability must be in the range (0, 1].")
return probability
def _normalize_columns(columns, provided_kwargs):
"""Converts arguments (columns, and kwargs dict) to C++ representation.
Args:
columns: a data structure containing the column families and qualifiers to
retrieve. Valid types include (1) None, (2) list of tuples, (3) a tuple of
strings.
provided_kwargs: a dictionary containing the column families and qualifiers
to retrieve.
Returns:
A list of pairs of column family+qualifier to retrieve.
Raises:
ValueError: If there are no cells to retrieve or the columns are in an
incorrect format.
"""
normalized = columns
if normalized is None:
normalized = []
if isinstance(normalized, tuple):
if len(normalized) == 2:
normalized = [normalized]
else:
raise ValueError("columns was a tuple of inappropriate length")
for key, value in iteritems(provided_kwargs):
if key == "name":
continue
if isinstance(value, string_types):
normalized.append((key, value))
continue
for col in value:
normalized.append((key, col))
if not normalized:
raise ValueError("At least one column + column family must be specified.")
return normalized
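# For example, _normalize_columns(None, {"cf1": ["c1", "c2"], "cf2": "c3"})
# returns [("cf1", "c1"), ("cf1", "c2"), ("cf2", "c3")] (kwargs iteration
# order aside), and _normalize_columns(("cf1", "c1"), {}) returns
# [("cf1", "c1")].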
class _BigtableKeyDataset(dataset_ops.DatasetSource):
"""_BigtableKeyDataset is an abstract class representing the keys of a table.
"""
def __init__(self, table, variant_tensor):
"""Constructs a _BigtableKeyDataset.
Args:
table: a Bigtable class.
variant_tensor: DT_VARIANT representation of the dataset.
"""
super(_BigtableKeyDataset, self).__init__(variant_tensor)
self._table = table
@property
def element_spec(self):
return tensor_spec.TensorSpec([], dtypes.string)
class _BigtablePrefixKeyDataset(_BigtableKeyDataset):
"""_BigtablePrefixKeyDataset represents looking up keys by prefix.
"""
def __init__(self, table, prefix):
self._prefix = prefix
variant_tensor = gen_bigtable_ops.bigtable_prefix_key_dataset(
table=table._resource, # pylint: disable=protected-access
prefix=self._prefix)
super(_BigtablePrefixKeyDataset, self).__init__(table, variant_tensor)
class _BigtableRangeKeyDataset(_BigtableKeyDataset):
"""_BigtableRangeKeyDataset represents looking up keys by range.
"""
def __init__(self, table, start, end):
self._start = start
self._end = end
variant_tensor = gen_bigtable_ops.bigtable_range_key_dataset(
table=table._resource, # pylint: disable=protected-access
start_key=self._start,
end_key=self._end)
super(_BigtableRangeKeyDataset, self).__init__(table, variant_tensor)
class _BigtableSampleKeysDataset(_BigtableKeyDataset):
"""_BigtableSampleKeysDataset represents a sampling of row keys.
"""
# TODO(saeta): Expose the data size offsets into the keys.
def __init__(self, table):
variant_tensor = gen_bigtable_ops.bigtable_sample_keys_dataset(
table=table._resource) # pylint: disable=protected-access
super(_BigtableSampleKeysDataset, self).__init__(table, variant_tensor)
class _BigtableLookupDataset(dataset_ops.DatasetSource):
"""_BigtableLookupDataset represents a dataset that retrieves values for keys.
"""
def __init__(self, dataset, table, normalized):
self._num_outputs = len(normalized) + 1 # 1 for row key
self._dataset = dataset
self._table = table
self._normalized = normalized
self._column_families = [i[0] for i in normalized]
self._columns = [i[1] for i in normalized]
variant_tensor = gen_bigtable_ops.bigtable_lookup_dataset(
keys_dataset=self._dataset._variant_tensor, # pylint: disable=protected-access
table=self._table._resource, # pylint: disable=protected-access
column_families=self._column_families,
columns=self._columns)
super(_BigtableLookupDataset, self).__init__(variant_tensor)
@property
def element_spec(self):
return tuple([tensor_spec.TensorSpec([], dtypes.string)] *
self._num_outputs)
class _BigtableScanDataset(dataset_ops.DatasetSource):
"""_BigtableScanDataset represents a dataset that retrieves keys and values.
"""
def __init__(self, table, prefix, start, end, normalized, probability):
self._table = table
self._prefix = prefix
self._start = start
self._end = end
self._column_families = [i[0] for i in normalized]
self._columns = [i[1] for i in normalized]
self._probability = probability
self._num_outputs = len(normalized) + 1 # 1 for row key
variant_tensor = gen_bigtable_ops.bigtable_scan_dataset(
table=self._table._resource, # pylint: disable=protected-access
prefix=self._prefix,
start_key=self._start,
end_key=self._end,
column_families=self._column_families,
columns=self._columns,
probability=self._probability)
super(_BigtableScanDataset, self).__init__(variant_tensor)
@property
def element_spec(self):
return tuple([tensor_spec.TensorSpec([], dtypes.string)] *
self._num_outputs)
class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource):
"""_BigtableSampleKeyPairsDataset returns key pairs from a Bigtable table.
"""
def __init__(self, table, prefix, start, end):
self._table = table
self._prefix = prefix
self._start = start
self._end = end
variant_tensor = gen_bigtable_ops.bigtable_sample_key_pairs_dataset(
table=self._table._resource, # pylint: disable=protected-access
prefix=self._prefix,
start_key=self._start,
end_key=self._end)
super(_BigtableSampleKeyPairsDataset, self).__init__(variant_tensor)
@property
def element_spec(self):
return (tensor_spec.TensorSpec([], dtypes.string),
tensor_spec.TensorSpec([], dtypes.string))


@ -1,614 +0,0 @@
# TensorFlow code for training gradient boosted trees.
load("//tensorflow:tensorflow.bzl", "py_test", "tf_custom_op_library", "tf_gen_op_libs", "tf_gen_op_wrapper_py", "tf_kernel_library")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
exports_files(["LICENSE"])
package_group(name = "friends")
cc_library(
name = "boosted_trees_kernels",
deps = [
":model_ops_kernels",
":prediction_ops_kernels",
":quantile_ops_kernels",
":split_handler_ops_kernels",
":stats_accumulator_ops_kernels",
":training_ops_kernels",
],
alwayslink = 1,
)
cc_library(
name = "boosted_trees_ops_op_lib",
deps = [
":model_ops_op_lib",
":prediction_ops_op_lib",
":quantile_ops_op_lib",
":split_handler_ops_op_lib",
":stats_accumulator_ops_op_lib",
":training_ops_op_lib",
],
)
py_library(
name = "init_py",
srcs = [
"__init__.py",
"python/__init__.py",
],
srcs_version = "PY2AND3",
deps = [
":boosted_trees_ops_py",
":losses",
],
)
py_library(
name = "losses",
srcs = ["python/utils/losses.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
],
)
py_test(
name = "losses_test",
size = "small",
srcs = ["python/utils/losses_test.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":losses",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//third_party/py/numpy",
],
)
py_library(
name = "gbdt_batch",
srcs = [
"python/training/functions/gbdt_batch.py",
],
srcs_version = "PY2AND3",
deps = [
":gen_model_ops_py",
"//tensorflow/contrib/boosted_trees:batch_ops_utils_py",
"//tensorflow/contrib/boosted_trees:boosted_trees_ops_py",
"//tensorflow/contrib/boosted_trees/lib:categorical_split_handler",
"//tensorflow/contrib/boosted_trees/lib:ordinal_split_handler",
"//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:stateless_random_ops",
"//tensorflow/python:summary",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/feature_column",
],
)
py_test(
name = "gbdt_batch_test",
size = "medium",
srcs = ["python/training/functions/gbdt_batch_test.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
tags = [
"nofwdcompat", # b/137641346
"notsan", # b/62863147
],
deps = [
":gbdt_batch",
":losses",
":model_ops_py",
"//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:resources",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:variables",
],
)
# Kernel tests
py_test(
name = "model_ops_test",
size = "small",
srcs = ["python/kernel_tests/model_ops_test.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":model_ops_py",
":prediction_ops_py",
"//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:resources",
"//tensorflow/python:training",
"//tensorflow/python:variables",
"//third_party/py/numpy",
],
)
py_test(
name = "prediction_ops_test",
size = "small",
srcs = ["python/kernel_tests/prediction_ops_test.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":model_ops_py",
":prediction_ops_py",
"//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:resources",
"//third_party/py/numpy",
],
)
py_test(
name = "quantile_ops_test",
size = "small",
srcs = ["python/kernel_tests/quantile_ops_test.py"],
python_version = "PY2",
shard_count = 3,
srcs_version = "PY2AND3",
deps = [
":quantile_ops_py",
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:resources",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:training",
"//third_party/py/numpy",
],
)
py_test(
name = "split_handler_ops_test",
size = "small",
srcs = ["python/kernel_tests/split_handler_ops_test.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":split_handler_ops_py",
"//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
py_test(
name = "stats_accumulator_ops_test",
size = "small",
srcs = ["python/kernel_tests/stats_accumulator_ops_test.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":stats_accumulator_ops_py",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:tensor_shape",
],
)
py_test(
name = "training_ops_test",
size = "small",
srcs = ["python/kernel_tests/training_ops_test.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":model_ops_py",
":training_ops_py",
"//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_py",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:resources",
"//third_party/py/numpy",
],
)
# Ops
py_library(
name = "batch_ops_utils_py",
srcs = ["python/ops/batch_ops_utils.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:tensor_shape",
],
)
tf_custom_op_py_library(
name = "boosted_trees_ops_loader",
srcs = ["python/ops/boosted_trees_ops_loader.py"],
dso = [":python/ops/_boosted_trees_ops.so"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:errors",
"//tensorflow/python:platform",
],
)
py_library(
name = "boosted_trees_ops_py",
srcs_version = "PY2AND3",
deps = [
":model_ops_py",
":prediction_ops_py",
":quantile_ops_py",
":split_handler_ops_py",
":stats_accumulator_ops_py",
":training_ops_py",
],
)
# Model Ops.
tf_gen_op_libs(
op_lib_names = ["model_ops"],
)
tf_gen_op_wrapper_py(
name = "gen_model_ops_py",
out = "python/ops/gen_model_ops.py",
deps = [":model_ops_op_lib"],
)
tf_custom_op_py_library(
name = "model_ops_py",
srcs = ["python/ops/model_ops.py"],
kernels = [
":model_ops_kernels",
":model_ops_op_lib",
],
srcs_version = "PY2AND3",
deps = [
":boosted_trees_ops_loader",
":gen_model_ops_py",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:resources",
"//tensorflow/python:training",
],
)
tf_kernel_library(
name = "model_ops_kernels",
srcs = ["kernels/model_ops.cc"],
deps = [
"//tensorflow/contrib/boosted_trees/lib:utils",
"//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
],
alwayslink = 1,
)
tf_custom_op_library(
name = "python/ops/_boosted_trees_ops.so",
srcs = [
"kernels/model_ops.cc",
"kernels/prediction_ops.cc",
"kernels/quantile_ops.cc",
"kernels/split_handler_ops.cc",
"kernels/stats_accumulator_ops.cc",
"kernels/training_ops.cc",
"ops/model_ops.cc",
"ops/prediction_ops.cc",
"ops/quantile_ops.cc",
"ops/split_handler_ops.cc",
"ops/stats_accumulator_ops.cc",
"ops/training_ops.cc",
],
deps = [
"//tensorflow/contrib/boosted_trees/lib:example_partitioner",
"//tensorflow/contrib/boosted_trees/lib:models",
"//tensorflow/contrib/boosted_trees/lib:node-stats",
"//tensorflow/contrib/boosted_trees/lib:utils",
"//tensorflow/contrib/boosted_trees/lib:weighted_quantiles",
"//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc",
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
"//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
"//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource",
"//tensorflow/contrib/boosted_trees/resources:stamped_resource",
],
)
# Split handler Ops.
tf_gen_op_libs(
op_lib_names = ["split_handler_ops"],
)
tf_gen_op_wrapper_py(
name = "gen_split_handler_ops_py",
out = "python/ops/gen_split_handler_ops.py",
deps = [
":split_handler_ops_op_lib",
],
)
tf_custom_op_py_library(
name = "split_handler_ops_py",
srcs = ["python/ops/split_handler_ops.py"],
kernels = [
":split_handler_ops_kernels",
":split_handler_ops_op_lib",
],
srcs_version = "PY2AND3",
deps = [
":boosted_trees_ops_loader",
":gen_split_handler_ops_py",
],
)
tf_kernel_library(
name = "split_handler_ops_kernels",
srcs = ["kernels/split_handler_ops.cc"],
deps = [
"//tensorflow/contrib/boosted_trees/lib:node-stats",
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:protos_all_cc",
"//third_party/eigen3",
],
alwayslink = 1,
)
# Training Ops.
tf_gen_op_libs(
op_lib_names = [
"training_ops",
],
deps = ["//tensorflow/contrib/boosted_trees/proto:learner_proto_cc"],
)
tf_gen_op_wrapper_py(
name = "gen_training_ops_py",
out = "python/ops/gen_training_ops.py",
deps = [
":training_ops_op_lib",
],
)
tf_custom_op_py_library(
name = "training_ops_py",
srcs = ["python/ops/training_ops.py"],
kernels = [
":training_ops_kernels",
":training_ops_op_lib",
],
srcs_version = "PY2AND3",
deps = [
":boosted_trees_ops_loader",
":gen_training_ops_py",
],
)
tf_kernel_library(
name = "training_ops_kernels",
srcs = ["kernels/training_ops.cc"],
deps = [
"//tensorflow/contrib/boosted_trees/lib:utils",
"//tensorflow/contrib/boosted_trees/lib:weighted_quantiles",
"//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc",
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
"//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
"//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource",
"//tensorflow/core:framework_headers_lib",
],
alwayslink = 1,
)
# Prediction Ops.
tf_gen_op_libs(
op_lib_names = ["prediction_ops"],
deps = ["//tensorflow/contrib/boosted_trees/proto:learner_proto_cc"],
)
tf_gen_op_wrapper_py(
name = "gen_prediction_ops_py",
out = "python/ops/gen_prediction_ops.py",
deps = [
":prediction_ops_op_lib",
],
)
tf_custom_op_py_library(
name = "prediction_ops_py",
srcs = ["python/ops/prediction_ops.py"],
kernels = [
":prediction_ops_kernels",
":prediction_ops_op_lib",
],
srcs_version = "PY2AND3",
deps = [
":boosted_trees_ops_loader",
":gen_prediction_ops_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:framework_for_generated_wrappers",
],
)
tf_kernel_library(
name = "prediction_ops_kernels",
srcs = ["kernels/prediction_ops.cc"],
deps = [
"//tensorflow/contrib/boosted_trees/lib:example_partitioner",
"//tensorflow/contrib/boosted_trees/lib:models",
"//tensorflow/contrib/boosted_trees/lib:utils",
"//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
"//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
],
alwayslink = 1,
)
# Quantile ops
tf_gen_op_libs(
op_lib_names = ["quantile_ops"],
)
tf_gen_op_wrapper_py(
name = "gen_quantile_ops_py_wrap",
out = "python/ops/gen_quantile_ops.py",
deps = [
":quantile_ops_op_lib",
],
)
tf_custom_op_py_library(
name = "quantile_ops_py",
srcs = ["python/ops/quantile_ops.py"],
kernels = [
":quantile_ops_kernels",
":quantile_ops_op_lib",
],
srcs_version = "PY2AND3",
deps = [
":batch_ops_utils_py",
":boosted_trees_ops_loader",
":gen_quantile_ops_py_wrap",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:resources",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:training",
],
)
tf_kernel_library(
name = "quantile_ops_kernels",
srcs = ["kernels/quantile_ops.cc"],
deps = [
"//tensorflow/contrib/boosted_trees/lib:utils",
"//tensorflow/contrib/boosted_trees/lib:weighted_quantiles",
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc",
"//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource",
"//tensorflow/core:framework_headers_lib",
],
alwayslink = 1,
)
# Stats Accumulator ops
tf_gen_op_libs(
op_lib_names = ["stats_accumulator_ops"],
)
tf_gen_op_wrapper_py(
name = "gen_stats_accumulator_ops_py_wrap",
out = "python/ops/gen_stats_accumulator_ops.py",
deps = [
":stats_accumulator_ops_op_lib",
],
)
tf_custom_op_py_library(
name = "stats_accumulator_ops_py",
srcs = ["python/ops/stats_accumulator_ops.py"],
kernels = [
":stats_accumulator_ops_kernels",
":stats_accumulator_ops_op_lib",
],
srcs_version = "PY2AND3",
deps = [
":batch_ops_utils_py",
":boosted_trees_ops_loader",
":gen_stats_accumulator_ops_py_wrap",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:resources",
"//tensorflow/python:training",
],
)
tf_kernel_library(
name = "stats_accumulator_ops_kernels",
srcs = ["kernels/stats_accumulator_ops.cc"],
deps = [
"//tensorflow/contrib/boosted_trees/lib:utils",
"//tensorflow/contrib/boosted_trees/resources:stamped_resource",
"//tensorflow/core:framework_headers_lib",
],
alwayslink = 1,
)
# Pip
py_library(
name = "boosted_trees_pip",
deps = [
":init_py",
"//tensorflow/contrib/boosted_trees:gbdt_batch",
"//tensorflow/contrib/boosted_trees/estimator_batch:custom_export_strategy",
"//tensorflow/contrib/boosted_trees/estimator_batch:dnn_tree_combined_estimator",
"//tensorflow/contrib/boosted_trees/estimator_batch:init_py",
"//tensorflow/contrib/boosted_trees/estimator_batch:trainer_hooks",
"//tensorflow/contrib/boosted_trees/lib:categorical_split_handler",
"//tensorflow/contrib/boosted_trees/lib:ordinal_split_handler",
"//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_py",
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_py",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py",
],
)

View File

@@ -1,11 +0,0 @@
# TF Boosted Trees (TFBT)
TF Boosted Trees is an implementation of a gradient boosting algorithm that
uses decision trees as weak learners.
## Examples
Folder "examples" demonstrates how TFBT estimators can be used for various
problems. Namely, it contains:
* binary_mnist.py - an example on how to use TFBT for binary classification.
* mnist.py - a multiclass example.
* boston.py - a regression example.
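The snippet below is a minimal, hypothetical sketch of the estimator setup
these examples share; the exact `LearnerConfig` fields and constructor
arguments are assumptions drawn from the scripts, so consult them for the
authoritative usage.

```python
from tensorflow.contrib.boosted_trees.estimator_batch.estimator import (
    GradientBoostedDecisionTreeClassifier)
from tensorflow.contrib.boosted_trees.proto import learner_pb2

learner_config = learner_pb2.LearnerConfig()
learner_config.learning_rate_tuner.fixed.learning_rate = 0.1
learner_config.constraints.max_tree_depth = 4

classifier = GradientBoostedDecisionTreeClassifier(
    learner_config=learner_config,
    num_trees=50,
    examples_per_layer=1000,  # ideally the training-set size
    n_classes=2)
# classifier.fit(input_fn=train_input_fn)  # train_input_fn is a placeholder
```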

View File

@@ -1,22 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradient boosted trees implementation in tensorflow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.boosted_trees.python import *
# pylint: enable=unused-import,wildcard-import

View File

@@ -1,220 +0,0 @@
# This directory contains estimators to train and run inference on
# gradient boosted trees on top of TensorFlow.
load("//tensorflow:tensorflow.bzl", "py_test")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
exports_files(["LICENSE"])
py_library(
name = "init_py",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
":custom_export_strategy",
":custom_loss_head",
":distillation_loss",
":estimator",
":model",
":trainer_hooks",
],
)
py_library(
name = "model",
srcs = ["model.py"],
srcs_version = "PY2AND3",
deps = [
":estimator_utils",
":trainer_hooks",
"//tensorflow/contrib/boosted_trees:gbdt_batch",
"//tensorflow/contrib/boosted_trees:model_ops_py",
"//tensorflow/python:framework_ops",
"//tensorflow/python:state_ops",
"//tensorflow/python:training_util",
],
)
py_library(
name = "trainer_hooks",
srcs = ["trainer_hooks.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/learn",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:training",
],
)
py_library(
name = "estimator_utils",
srcs = ["estimator_utils.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/learn",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
],
)
py_test(
name = "trainer_hooks_test",
size = "small",
srcs = ["trainer_hooks_test.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
tags = [
# TODO(b/134605801): Re-enable this test.
"no_oss",
],
deps = [
":trainer_hooks",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
"//tensorflow/python:state_ops",
"//tensorflow/python:training",
"//tensorflow/python:variables",
],
)
py_library(
name = "custom_loss_head",
srcs = ["custom_loss_head.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/learn",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:math_ops",
],
)
py_library(
name = "custom_export_strategy",
srcs = ["custom_export_strategy.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/boosted_trees:gbdt_batch",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py",
"//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_py",
"//tensorflow/contrib/decision_trees/proto:generic_tree_model_py",
"//tensorflow/contrib/learn",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
"//tensorflow/python:session",
"//tensorflow/python/saved_model:loader",
"//tensorflow/python/saved_model:tag_constants",
],
)
py_test(
name = "custom_export_strategy_test",
size = "small",
srcs = ["custom_export_strategy_test.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":custom_export_strategy",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
py_library(
name = "estimator",
srcs = ["estimator.py"],
srcs_version = "PY2AND3",
deps = [
":custom_loss_head",
":estimator_utils",
":model",
"//tensorflow/contrib/boosted_trees:losses",
"//tensorflow/contrib/learn",
"//tensorflow/python:math_ops",
],
)
py_library(
name = "dnn_tree_combined_estimator",
srcs = ["dnn_tree_combined_estimator.py"],
srcs_version = "PY2AND3",
deps = [
":distillation_loss",
":estimator_utils",
":model",
":trainer_hooks",
"//tensorflow/contrib/boosted_trees:gbdt_batch",
"//tensorflow/contrib/boosted_trees:model_ops_py",
"//tensorflow/contrib/learn",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:state_ops",
"//tensorflow/python:training",
],
)
py_library(
name = "distillation_loss",
srcs = ["distillation_loss.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/learn",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
],
)
py_test(
name = "dnn_tree_combined_estimator_test",
size = "medium",
timeout = "long",
srcs = ["dnn_tree_combined_estimator_test.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
tags = [
"no_gpu",
"no_pip_gpu",
"notsan",
],
deps = [
":dnn_tree_combined_estimator",
"//tensorflow/contrib/boosted_trees:gbdt_batch",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
],
)
py_test(
name = "estimator_test",
size = "medium",
srcs = ["estimator_test.py"],
python_version = "PY2",
shard_count = 4,
srcs_version = "PY2AND3",
tags = [
"no_gpu",
"no_pip_gpu",
"notsan",
],
deps = [
":estimator",
"//tensorflow/contrib/boosted_trees:gbdt_batch",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
],
)

View File

@@ -1,22 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradient boosted trees implementation in tensorflow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.boosted_trees.estimator_batch import *
# pylint: enable=unused-import,wildcard-import

View File

@@ -1,267 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Strategy to export custom proto formats."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
from tensorflow.contrib.decision_trees.proto import generic_tree_model_extensions_pb2
from tensorflow.contrib.decision_trees.proto import generic_tree_model_pb2
from tensorflow.contrib.learn.python.learn import export_strategy
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader as saved_model_loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.util import compat
_SPARSE_FLOAT_FEATURE_NAME_TEMPLATE = "%s_%d"
def make_custom_export_strategy(name,
convert_fn,
feature_columns,
export_input_fn,
use_core_columns=False,
feature_engineering_fn=None,
default_output_alternative_key=None):
"""Makes custom exporter of GTFlow tree format.
Args:
name: A string, for the name of the export strategy.
convert_fn: A function that converts the tree proto to the desired format and
saves it to the desired location. Can be None to skip conversion.
feature_columns: A list of feature columns.
export_input_fn: A function that takes no arguments and returns an
`InputFnOps`.
use_core_columns: A boolean, whether core feature columns were used.
feature_engineering_fn: Feature engineering function to be called on the input.
default_output_alternative_key: the name of the head to serve when an
incoming serving request does not explicitly request a specific head.
Not needed for single-headed models.
Returns:
An `ExportStrategy`.
"""
base_strategy = saved_model_export_utils.make_export_strategy(
serving_input_fn=export_input_fn,
strip_default_attrs=True,
default_output_alternative_key=default_output_alternative_key)
input_fn = export_input_fn()
features = input_fn.features
if feature_engineering_fn is not None:
features, _ = feature_engineering_fn(features, labels=None)
(sorted_feature_names, dense_floats, sparse_float_indices, _, _,
sparse_int_indices, _, _) = gbdt_batch.extract_features(
features, feature_columns, use_core_columns)
def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None):
"""A wrapper to export to SavedModel, and convert it to other formats."""
result_dir = base_strategy.export(estimator, export_dir,
checkpoint_path,
eval_result)
with ops.Graph().as_default() as graph:
with tf_session.Session(graph=graph) as sess:
saved_model_loader.load(
sess, [tag_constants.SERVING], result_dir)
# Note: This is a GTFlow-internal API and might change.
ensemble_model = graph.get_operation_by_name(
"ensemble_model/TreeEnsembleSerialize")
_, dfec_str = sess.run(ensemble_model.outputs)
dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
dtec.ParseFromString(dfec_str)
# Export the result in the same folder as the saved model.
if convert_fn:
convert_fn(dtec, sorted_feature_names,
len(dense_floats),
len(sparse_float_indices),
len(sparse_int_indices), result_dir, eval_result)
feature_importances = _get_feature_importances(
dtec, sorted_feature_names,
len(dense_floats),
len(sparse_float_indices), len(sparse_int_indices))
sorted_by_importance = sorted(
feature_importances.items(), key=lambda x: -x[1])
assets_dir = os.path.join(
compat.as_bytes(result_dir), compat.as_bytes("assets.extra"))
gfile.MakeDirs(assets_dir)
with gfile.GFile(os.path.join(
compat.as_bytes(assets_dir),
compat.as_bytes("feature_importances")), "w") as f:
f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance))
return result_dir
return export_strategy.ExportStrategy(
name, export_fn, strip_default_attrs=True)
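# --- Hypothetical usage sketch; not part of the original module. ---
# A toy convert_fn matching the contract documented above: it receives the
# tree ensemble proto plus feature metadata and writes an artifact next to
# the SavedModel. `my_feature_columns` and `my_serving_input_fn` below are
# placeholders, not real symbols; `os`, `gfile` and `compat` are already
# imported at the top of this module.
def _example_convert_fn(dtec, sorted_feature_names, num_dense,
                        num_sparse_float, num_sparse_int, export_dir,
                        eval_result):
  """Writes the raw tree ensemble text proto next to the exported model."""
  del sorted_feature_names, num_dense, num_sparse_float  # Unused in this toy.
  del num_sparse_int, eval_result
  with gfile.GFile(os.path.join(
      compat.as_bytes(export_dir), compat.as_bytes("ensemble.pbtxt")), "w") as f:
    f.write(str(dtec))
# strategy = make_custom_export_strategy(
#     name="universal_export",
#     convert_fn=_example_convert_fn,
#     feature_columns=my_feature_columns,
#     export_input_fn=my_serving_input_fn)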
def convert_to_universal_format(dtec, sorted_feature_names,
num_dense, num_sparse_float,
num_sparse_int,
feature_name_to_proto=None):
"""Convert GTFlow trees to universal format."""
del num_sparse_int # unused.
model_and_features = generic_tree_model_pb2.ModelAndFeatures()
# TODO(jonasz): Feature descriptions should contain information about how each
# feature is processed before it's fed to the model (e.g. bucketing
# information). As of now, this serves as a list of features the model uses.
for feature_name in sorted_feature_names:
if not feature_name_to_proto:
model_and_features.features[feature_name].SetInParent()
else:
model_and_features.features[feature_name].CopyFrom(
feature_name_to_proto[feature_name])
model = model_and_features.model
model.ensemble.summation_combination_technique.SetInParent()
for tree_idx in range(len(dtec.trees)):
gtflow_tree = dtec.trees[tree_idx]
tree_weight = dtec.tree_weights[tree_idx]
member = model.ensemble.members.add()
member.submodel_id.value = tree_idx
tree = member.submodel.decision_tree
for node_idx in range(len(gtflow_tree.nodes)):
gtflow_node = gtflow_tree.nodes[node_idx]
node = tree.nodes.add()
node_type = gtflow_node.WhichOneof("node")
node.node_id.value = node_idx
if node_type == "leaf":
leaf = gtflow_node.leaf
if leaf.HasField("vector"):
for weight in leaf.vector.value:
new_value = node.leaf.vector.value.add()
new_value.float_value = weight * tree_weight
else:
for index, weight in zip(
leaf.sparse_vector.index, leaf.sparse_vector.value):
new_value = node.leaf.sparse_vector.sparse_value[index]
new_value.float_value = weight * tree_weight
else:
node = node.binary_node
# Binary (split) nodes are handled below.
if node_type == "dense_float_binary_split":
split = gtflow_node.dense_float_binary_split
feature_id = split.feature_column
inequality_test = node.inequality_left_child_test
inequality_test.feature_id.id.value = sorted_feature_names[feature_id]
inequality_test.type = (
generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
inequality_test.threshold.float_value = split.threshold
elif node_type == "sparse_float_binary_split_default_left":
split = gtflow_node.sparse_float_binary_split_default_left.split
node.default_direction = (generic_tree_model_pb2.BinaryNode.LEFT)
feature_id = split.feature_column + num_dense
inequality_test = node.inequality_left_child_test
inequality_test.feature_id.id.value = (
_SPARSE_FLOAT_FEATURE_NAME_TEMPLATE %
(sorted_feature_names[feature_id], split.dimension_id))
model_and_features.features.pop(sorted_feature_names[feature_id])
(model_and_features.features[inequality_test.feature_id.id.value]
.SetInParent())
inequality_test.type = (
generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
inequality_test.threshold.float_value = split.threshold
elif node_type == "sparse_float_binary_split_default_right":
split = gtflow_node.sparse_float_binary_split_default_right.split
node.default_direction = (
generic_tree_model_pb2.BinaryNode.RIGHT)
# TODO(nponomareva): adjust this id assignment when we allow multi-
# column sparse tensors.
feature_id = split.feature_column + num_dense
inequality_test = node.inequality_left_child_test
inequality_test.feature_id.id.value = (
_SPARSE_FLOAT_FEATURE_NAME_TEMPLATE %
(sorted_feature_names[feature_id], split.dimension_id))
model_and_features.features.pop(sorted_feature_names[feature_id])
(model_and_features.features[inequality_test.feature_id.id.value]
.SetInParent())
inequality_test.type = (
generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL)
inequality_test.threshold.float_value = split.threshold
elif node_type == "categorical_id_binary_split":
split = gtflow_node.categorical_id_binary_split
node.default_direction = generic_tree_model_pb2.BinaryNode.RIGHT
feature_id = split.feature_column + num_dense + num_sparse_float
categorical_test = (
generic_tree_model_extensions_pb2.MatchingValuesTest())
categorical_test.feature_id.id.value = sorted_feature_names[
feature_id]
matching_id = categorical_test.value.add()
matching_id.int64_value = split.feature_id
node.custom_left_child_test.Pack(categorical_test)
elif (node_type == "oblivious_dense_float_binary_split" or
node_type == "oblivious_categorical_id_binary_split"):
raise ValueError("Universal tree format doesn't support oblivious "
"trees")
else:
raise ValueError("Unexpected node type %s" % node_type)
node.left_child_id.value = split.left_id
node.right_child_id.value = split.right_id
return model_and_features
def _get_feature_importances(dtec, feature_names, num_dense_floats,
num_sparse_float, num_sparse_int):
"""Export the feature importance per feature column."""
del num_sparse_int # Unused.
sums = collections.defaultdict(lambda: 0)
for tree_idx in range(len(dtec.trees)):
tree = dtec.trees[tree_idx]
for tree_node in tree.nodes:
node_type = tree_node.WhichOneof("node")
if node_type == "dense_float_binary_split":
split = tree_node.dense_float_binary_split
split_column = feature_names[split.feature_column]
elif node_type == "sparse_float_binary_split_default_left":
split = tree_node.sparse_float_binary_split_default_left.split
split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % (
feature_names[split.feature_column + num_dense_floats],
split.dimension_id)
elif node_type == "sparse_float_binary_split_default_right":
split = tree_node.sparse_float_binary_split_default_right.split
split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % (
feature_names[split.feature_column + num_dense_floats],
split.dimension_id)
elif node_type == "categorical_id_binary_split":
split = tree_node.categorical_id_binary_split
split_column = feature_names[split.feature_column + num_dense_floats +
num_sparse_float]
elif node_type == "oblivious_dense_float_binary_split":
split = tree_node.oblivious_dense_float_binary_split
split_column = feature_names[split.feature_column]
elif node_type == "oblivious_categorical_id_binary_split":
split = tree_node.oblivious_categorical_id_binary_split
split_column = feature_names[split.feature_column + num_dense_floats +
num_sparse_float]
elif node_type == "categorical_id_set_membership_binary_split":
split = tree_node.categorical_id_set_membership_binary_split
split_column = feature_names[split.feature_column + num_dense_floats +
num_sparse_float]
elif node_type == "leaf":
assert tree_node.node_metadata.gain == 0
continue
else:
raise ValueError("Unexpected split type %s" % node_type)
# Apply the shrinkage factor. This matters because shrinkage is not
# necessarily uniform across trees.
sums[split_column] += (
tree_node.node_metadata.gain * dtec.tree_weights[tree_idx])
return dict(sums)

View File

@@ -1,364 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the conversion code and for feature importances export.
Tests that cover conversion from the TFBT format to the
tensorflow.contrib.decision_trees generic_tree_model format and feature
importances export.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from google.protobuf import text_format
from tensorflow.contrib.boosted_trees.estimator_batch import custom_export_strategy
from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
class ConvertModelTest(test_util.TensorFlowTestCase):
def _make_trees(self):
dtec_str = """
trees {
nodes {
leaf {
vector {
value: -1
}
}
}
}
trees {
nodes {
dense_float_binary_split {
feature_column: 0
threshold: 1740.0
left_id: 1
right_id: 2
}
node_metadata {
gain: 500
}
}
nodes {
leaf {
vector {
value: 0.6
}
}
}
nodes {
sparse_float_binary_split_default_left {
split {
feature_column: 0
threshold: 1500.0
left_id: 3
right_id: 4
}
}
node_metadata {
gain: 500
}
}
nodes {
categorical_id_binary_split {
feature_column: 0
feature_id: 5
left_id: 5
right_id: 6
}
node_metadata {
gain: 500
}
}
nodes {
leaf {
vector {
value: 0.8
}
}
}
nodes {
leaf {
vector {
value: 0.5
}
}
}
nodes {
sparse_float_binary_split_default_right {
split {
feature_column: 1
dimension_id:3
threshold: -0.4
left_id: 7
right_id: 8
}
}
node_metadata {
gain: 3600
}
}
nodes {
leaf {
vector {
value: 0.36
}
}
}
nodes {
leaf {
vector {
value: 18
}
}
}
}
tree_weights: 1.0
tree_weights: 0.1
"""
dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge(dtec_str, dtec)
feature_columns = [
"feature_b",
"feature_a",
"feature_a_m",
"feature_d",
]
return dtec, feature_columns
def testConvertModel(self):
dtec, feature_columns = self._make_trees()
# Assume 2 sparse float columns, one with 1 dimension, the second one with
# 5 dimensions.
# The feature columns in the order they were added.
out = custom_export_strategy.convert_to_universal_format(
dtec, feature_columns, 1, 2, 1)
# Features a and a_m are sparse float features, a_m is multidimensional.
expected_tree = """
features { key: "feature_a_0" }
features { key: "feature_a_m_3" }
features { key: "feature_b" }
features { key: "feature_d" }
model {
ensemble {
summation_combination_technique {
}
members {
submodel {
decision_tree {
nodes {
node_id {
}
leaf {
vector {
value {
float_value: -1.0
}
}
}
}
}
}
submodel_id {
}
}
members {
submodel {
decision_tree {
nodes {
node_id {
}
binary_node {
left_child_id {
value: 1
}
right_child_id {
value: 2
}
inequality_left_child_test {
feature_id {
id {
value: "feature_b"
}
}
threshold {
float_value: 1740.0
}
}
}
}
nodes {
node_id {
value: 1
}
leaf {
vector {
value {
float_value: 0.06
}
}
}
}
nodes {
node_id {
value: 2
}
binary_node {
left_child_id {
value: 3
}
right_child_id {
value: 4
}
inequality_left_child_test {
feature_id {
id {
value: "feature_a_0"
}
}
threshold {
float_value: 1500.0
}
}
}
}
nodes {
node_id {
value: 3
}
binary_node {
left_child_id {
value: 5
}
right_child_id {
value: 6
}
default_direction: RIGHT
custom_left_child_test {
[type.googleapis.com/tensorflow.decision_trees.MatchingValuesTest] {
feature_id {
id {
value: "feature_d"
}
}
value {
int64_value: 5
}
}
}
}
}
nodes {
node_id {
value: 4
}
leaf {
vector {
value {
float_value: 0.08
}
}
}
}
nodes {
node_id {
value: 5
}
leaf {
vector {
value {
float_value: 0.05
}
}
}
}
nodes {
node_id {
value: 6
}
binary_node {
left_child_id {
value: 7
}
right_child_id {
value: 8
}
default_direction: RIGHT
inequality_left_child_test {
feature_id {
id {
value: "feature_a_m_3"
}
}
threshold {
float_value: -0.4
}
}
}
}
nodes {
node_id {
value: 7
}
leaf {
vector {
value {
float_value: 0.036
}
}
}
}
nodes {
node_id {
value: 8
}
leaf {
vector {
value {
float_value: 1.8
}
}
}
}
}
}
submodel_id {
value: 1
}
}
}
}"""
self.assertProtoEquals(expected_tree, out)
def testFeatureImportance(self):
dtec, feature_columns = self._make_trees()
feature_importances = custom_export_strategy._get_feature_importances(
dtec, feature_columns, 1, 2, 1)
self.assertItemsEqual(
["feature_b", "feature_a_0", "feature_a_m_3", "feature_d"],
feature_importances.keys())
self.assertAlmostEqual(50.0, feature_importances["feature_b"], places=4)
self.assertAlmostEqual(50.0, feature_importances["feature_a_0"], places=4)
self.assertAlmostEqual(50.0, feature_importances["feature_d"], places=4)
self.assertAlmostEqual(
360.0, feature_importances["feature_a_m_3"], places=4)
if __name__ == "__main__":
googletest.main()

View File

@@ -1,73 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of `head.Head` with custom loss and link function."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
class CustomLossHead(head_lib._RegressionHead): # pylint: disable=protected-access
"""A Head object with custom loss function and link function."""
def __init__(self,
loss_fn,
link_fn,
logit_dimension,
head_name=None,
weight_column_name=None,
metrics_fn=None):
"""`Head` for specifying arbitrary loss function.
Args:
loss_fn: Loss function.
link_fn: Function that converts logits to prediction.
logit_dimension: Number of dimensions for the logits.
head_name: name of the head. Predictions, summary, metrics keys are
suffixed by `"/" + head_name` and the default variable scope is
`head_name`.
weight_column_name: A string defining the feature column name representing
weights. It is used to down-weight or boost examples during training; it
will be multiplied by the loss of the example.
metrics_fn: A function that takes the predictions dict, labels and weights
and returns a dictionary of metrics to be calculated.
"""
def loss_wrapper(labels, logits, weight_tensor):
if weight_tensor is None:
weight_tensor = array_ops.ones(
shape=[array_ops.shape(labels)[0], 1], dtype=dtypes.float32)
weighted_loss, _ = loss_fn(labels, weight_tensor, logits)
average_loss = math_ops.reduce_mean(weighted_loss)
return average_loss, average_loss / math_ops.reduce_mean(weight_tensor)
super(CustomLossHead, self).__init__(
loss_fn=loss_wrapper,
link_fn=link_fn,
head_name=head_name,
weight_column_name=weight_column_name,
enable_centered_bias=False,
label_dimension=logit_dimension)
self._metrics_fn = metrics_fn
def _metrics(self, eval_loss, predictions, labels, weights):
if self._metrics_fn is not None:
return self._metrics_fn(predictions, labels, weights)
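# --- Hypothetical usage sketch; not part of the original file. ---
# The loss_fn contract matches loss_wrapper above: it is called as
# loss_fn(labels, weight_tensor, logits) and must return a tuple whose first
# element is the per-example weighted loss. `math_ops` and `array_ops` are
# already imported at the top of this module.
def _example_squared_loss(labels, weight_tensor, logits):
  per_example = math_ops.squared_difference(labels, logits) * weight_tensor
  return per_example, None
# head = CustomLossHead(
#     loss_fn=_example_squared_loss,
#     link_fn=array_ops.identity,  # Identity link for a regression head.
#     logit_dimension=1)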

View File

@@ -1,75 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utill functions for distillation loss.
The distillation loss_fn will be called with the following:
Args:
dnn_logits: Tensor of logits from the dnn, treated as the "target". This will
be the output of a call to tf.stop_gradient().
tree_logits: Tensor of logits from the tree, treated as the "predictions".
example_weights: Tensor of example weights, or a single scalar.
Returns:
A scalar indicating the reduced loss for that batch of examples.
Note: we call the loss_fn defined in contrib head, which computes two
losses: the first for training and the second for reporting. We only take the
first one here.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
def _logits_to_label_for_tree(logits, n_classes):
if n_classes == 2:
return math_ops.sigmoid(logits)
else:
return nn.softmax(logits)
def create_dnn_to_tree_squared_loss_fn(n_classes):
"""Returns a squared loss function for dnn to tree distillation."""
def _dnn_to_tree_squared_loss(dnn_logits, tree_logits, example_weights):
return head_lib._mean_squared_loss( # pylint: disable=protected-access
labels=_logits_to_label_for_tree(dnn_logits, n_classes),
logits=_logits_to_label_for_tree(tree_logits, n_classes),
weights=example_weights)[0]
return _dnn_to_tree_squared_loss
def create_dnn_to_tree_cross_entropy_loss_fn(n_classes):
"""Returns a cross entropy loss function for dnn to tree distillation."""
def _dnn_to_tree_cross_entropy_loss(dnn_logits, tree_logits, example_weights):
if n_classes == 2:
return head_lib._log_loss_with_two_classes( # pylint: disable=protected-access
labels=_logits_to_label_for_tree(dnn_logits, n_classes),
logits=tree_logits,
weights=example_weights)[0]
else:
return head_lib._softmax_cross_entropy_loss( # pylint: disable=protected-access
labels=_logits_to_label_for_tree(dnn_logits, n_classes),
logits=tree_logits,
weights=example_weights)[0]
return _dnn_to_tree_cross_entropy_loss
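# --- Hypothetical smoke test; not part of the original file. ---
# Demonstrates the loss_fn contract from the module docstring on toy logits;
# running it needs eager execution (or a session) to produce a value.
def _example_distillation_loss():
  import tensorflow as tf  # Local import; the module avoids the top-level API.
  loss_fn = create_dnn_to_tree_cross_entropy_loss_fn(n_classes=2)
  dnn_logits = tf.constant([[2.0], [-1.0]])   # The "target" network.
  tree_logits = tf.constant([[0.5], [0.0]])   # The "prediction" network.
  weights = tf.constant([[1.0], [1.0]])
  return loss_fn(tf.stop_gradient(dnn_logits), tree_logits, weights)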

View File

@@ -1,859 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow estimators for combined DNN + GBDT training model.
The combined model trains a DNN first, then trains boosted trees to boost the
logits of the DNN. The input layer of the DNN (including the embeddings learned
over sparse features) can optionally be provided to the boosted trees as
an additional input feature.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from tensorflow.contrib import layers
from tensorflow.contrib.boosted_trees.estimator_batch import model
from tensorflow.contrib.boosted_trees.estimator_batch import distillation_loss
from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks
from tensorflow.contrib.boosted_trees.python.ops import model_ops
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.python.feature_column import feature_column_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import training_util
_DNN_LEARNING_RATE = 0.001
def _get_optimizer(optimizer):
if callable(optimizer):
return optimizer()
else:
return optimizer
def _add_hidden_layer_summary(value, tag):
summary.scalar("%s_fraction_of_zero_values" % tag, nn.zero_fraction(value))
summary.histogram("%s_activation" % tag, value)
def _dnn_tree_combined_model_fn(
features,
labels,
mode,
head,
dnn_hidden_units,
dnn_feature_columns,
tree_learner_config,
num_trees,
tree_examples_per_layer,
config=None,
dnn_optimizer="Adagrad",
dnn_activation_fn=nn.relu,
dnn_dropout=None,
dnn_input_layer_partitioner=None,
dnn_input_layer_to_tree=True,
dnn_steps_to_train=10000,
predict_with_tree_only=False,
tree_feature_columns=None,
tree_center_bias=False,
dnn_to_tree_distillation_param=None,
use_core_versions=False,
output_type=model.ModelBuilderOutputType.MODEL_FN_OPS,
override_global_step_value=None):
"""DNN and GBDT combined model_fn.
Args:
features: `dict` of `Tensor` objects.
labels: Labels used to train on.
mode: Mode we are in. (TRAIN/EVAL/INFER)
head: A `Head` instance.
dnn_hidden_units: List of hidden units per layer.
dnn_feature_columns: An iterable containing all the feature columns
used by the model's DNN.
tree_learner_config: A config for the tree learner.
num_trees: Number of trees to grow model to after training DNN.
tree_examples_per_layer: Number of examples to accumulate before
growing the tree a layer. This value has a big impact on model
quality and should be set equal to the number of examples in
training dataset if possible. It can also be a function that computes
the number of examples based on the depth of the layer that's
being built.
config: `RunConfig` of the estimator.
dnn_optimizer: string, `Optimizer` object, or callable that defines the
optimizer to use for training the DNN. If `None`, will use the Adagrad
optimizer with default learning rate of 0.001.
dnn_activation_fn: Activation function applied to each layer of the DNN.
If `None`, will use `tf.nn.relu`.
dnn_dropout: When not `None`, the probability to drop out a given
unit in the DNN.
dnn_input_layer_partitioner: Partitioner for input layer of the DNN.
Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
dnn_input_layer_to_tree: Whether to provide the DNN's input layer
as a feature to the tree.
dnn_steps_to_train: Number of steps to train the DNN before switching
to the GBDT.
predict_with_tree_only: Whether to use only the tree model output as the
final prediction.
tree_feature_columns: An iterable containing all the feature columns
used by the model's boosted trees. If dnn_input_layer_to_tree is
set to True, these features are in addition to dnn_feature_columns.
tree_center_bias: Whether a separate tree should be created for
first fitting the bias.
dnn_to_tree_distillation_param: A tuple of (float, loss_fn), where the
float defines the weight of the distillation loss and loss_fn computes
the distillation loss from dnn_logits, tree_logits and a weight tensor.
If the entire tuple is None, no distillation will be applied. If only
the loss_fn is None, the sigmoid/softmax cross entropy loss is used by
default. When distillation is applied, `predict_with_tree_only` will be
set to True.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
(new interface).
override_global_step_value: If set, the global step is reset to this value
after training completes. This is particularly useful for hyperparameter
tuning, which can't recognize early stopping due to the number of trees.
If None, no override of the global step will happen.
Returns:
A `ModelFnOps` object.
Raises:
ValueError: if inputs are not valid.
"""
if not isinstance(features, dict):
raise ValueError("features should be a dictionary of `Tensor`s. "
"Given type: {}".format(type(features)))
if not dnn_feature_columns:
raise ValueError("dnn_feature_columns must be specified")
if dnn_to_tree_distillation_param:
if not predict_with_tree_only:
logging.warning("update predict_with_tree_only to True since distillation"
"is specified.")
predict_with_tree_only = True
# Build DNN Logits.
dnn_parent_scope = "dnn"
dnn_partitioner = dnn_input_layer_partitioner or (
partitioned_variables.min_max_variable_partitioner(
max_partitions=config.num_ps_replicas, min_slice_size=64 << 20))
if (output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC and
not use_core_versions):
raise ValueError("You must use core versions with Estimator Spec")
global_step = training_util.get_global_step()
with variable_scope.variable_scope(
dnn_parent_scope,
values=tuple(six.itervalues(features)),
partitioner=dnn_partitioner):
with variable_scope.variable_scope(
"input_from_feature_columns",
values=tuple(six.itervalues(features)),
partitioner=dnn_partitioner) as input_layer_scope:
if use_core_versions:
input_layer = feature_column_lib.input_layer(
features=features,
feature_columns=dnn_feature_columns,
weight_collections=[dnn_parent_scope])
else:
input_layer = layers.input_from_feature_columns(
columns_to_tensors=features,
feature_columns=dnn_feature_columns,
weight_collections=[dnn_parent_scope],
scope=input_layer_scope)
def dnn_logits_fn():
"""Builds the logits from the input layer."""
previous_layer = input_layer
for layer_id, num_hidden_units in enumerate(dnn_hidden_units):
with variable_scope.variable_scope(
"hiddenlayer_%d" % layer_id,
values=(previous_layer,)) as hidden_layer_scope:
net = layers.fully_connected(
previous_layer,
num_hidden_units,
activation_fn=dnn_activation_fn,
variables_collections=[dnn_parent_scope],
scope=hidden_layer_scope)
if dnn_dropout is not None and mode == model_fn.ModeKeys.TRAIN:
net = layers.dropout(net, keep_prob=(1.0 - dnn_dropout))
_add_hidden_layer_summary(net, hidden_layer_scope.name)
previous_layer = net
with variable_scope.variable_scope(
"logits", values=(previous_layer,)) as logits_scope:
dnn_logits = layers.fully_connected(
previous_layer,
head.logits_dimension,
activation_fn=None,
variables_collections=[dnn_parent_scope],
scope=logits_scope)
_add_hidden_layer_summary(dnn_logits, logits_scope.name)
return dnn_logits
if predict_with_tree_only and mode == model_fn.ModeKeys.INFER:
dnn_logits = array_ops.constant(0.0)
dnn_train_op_fn = control_flow_ops.no_op
elif predict_with_tree_only and mode == model_fn.ModeKeys.EVAL:
dnn_logits = control_flow_ops.cond(
global_step > dnn_steps_to_train,
lambda: array_ops.constant(0.0),
dnn_logits_fn)
dnn_train_op_fn = control_flow_ops.no_op
else:
dnn_logits = dnn_logits_fn()
def dnn_train_op_fn(loss):
"""Returns the op to optimize the loss."""
return optimizers.optimize_loss(
loss=loss,
global_step=training_util.get_global_step(),
learning_rate=_DNN_LEARNING_RATE,
optimizer=_get_optimizer(dnn_optimizer),
name=dnn_parent_scope,
variables=ops.get_collection(
ops.GraphKeys.TRAINABLE_VARIABLES, scope=dnn_parent_scope),
# Empty summaries to prevent optimizers from logging training_loss.
summaries=[])
# Build Tree Logits.
with ops.device(global_step.device):
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0,
tree_ensemble_config="", # Initialize an empty ensemble.
name="ensemble_model")
tree_features = features.copy()
if dnn_input_layer_to_tree:
tree_features["dnn_input_layer"] = input_layer
tree_feature_columns.append(layers.real_valued_column("dnn_input_layer"))
gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
is_chief=config.is_chief,
num_ps_replicas=config.num_ps_replicas,
ensemble_handle=ensemble_handle,
center_bias=tree_center_bias,
examples_per_layer=tree_examples_per_layer,
learner_config=tree_learner_config,
feature_columns=tree_feature_columns,
logits_dimension=head.logits_dimension,
features=tree_features,
use_core_columns=use_core_versions)
with ops.name_scope("gbdt"):
predictions_dict = gbdt_model.predict(mode)
tree_logits = predictions_dict["predictions"]
def _tree_train_op_fn(loss):
"""Returns the op to optimize the loss."""
if dnn_to_tree_distillation_param:
loss_weight, loss_fn = dnn_to_tree_distillation_param
# pylint: disable=protected-access
if use_core_versions:
weight_tensor = head_lib._weight_tensor(features, head._weight_column)
else:
weight_tensor = head_lib._weight_tensor(
features, head.weight_column_name)
# pylint: enable=protected-access
dnn_logits_fixed = array_ops.stop_gradient(dnn_logits)
if loss_fn is None:
# We create a loss_fn similar to the head loss_fn for multi_class_head,
# which was previously used as the default.
n_classes = 2 if head.logits_dimension == 1 else head.logits_dimension
loss_fn = distillation_loss.create_dnn_to_tree_cross_entropy_loss_fn(
n_classes)
dnn_to_tree_distillation_loss = loss_weight * loss_fn(
dnn_logits_fixed, tree_logits, weight_tensor)
summary.scalar("dnn_to_tree_distillation_loss",
dnn_to_tree_distillation_loss)
loss += dnn_to_tree_distillation_loss
update_op = gbdt_model.train(loss, predictions_dict, labels)
with ops.control_dependencies(
[update_op]), (ops.colocate_with(global_step)):
update_op = state_ops.assign_add(global_step, 1).op
return update_op
if predict_with_tree_only:
if mode == model_fn.ModeKeys.TRAIN or mode == model_fn.ModeKeys.INFER:
tree_train_logits = tree_logits
else:
tree_train_logits = control_flow_ops.cond(
global_step > dnn_steps_to_train,
lambda: tree_logits,
lambda: dnn_logits)
else:
tree_train_logits = dnn_logits + tree_logits
def _no_train_op_fn(loss):
"""Returns a no-op."""
del loss
return control_flow_ops.no_op()
if tree_center_bias:
num_trees += 1
finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
if output_type == model.ModelBuilderOutputType.MODEL_FN_OPS:
model_fn_ops = head.create_model_fn_ops(
features=features,
mode=mode,
labels=labels,
train_op_fn=_no_train_op_fn,
logits=tree_train_logits)
if mode != model_fn.ModeKeys.TRAIN:
return model_fn_ops
dnn_train_op = head.create_model_fn_ops(
features=features,
mode=mode,
labels=labels,
train_op_fn=dnn_train_op_fn,
logits=dnn_logits).train_op
tree_train_op = head.create_model_fn_ops(
features=tree_features,
mode=mode,
labels=labels,
train_op_fn=_tree_train_op_fn,
logits=tree_train_logits).train_op
# Add the hooks
model_fn_ops.training_hooks.extend([
trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train,
tree_train_op),
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
finalized_trees,
override_global_step_value)
])
return model_fn_ops
elif output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC:
fusion_spec = head.create_estimator_spec(
features=features,
mode=mode,
labels=labels,
train_op_fn=_no_train_op_fn,
logits=tree_train_logits)
if mode != model_fn.ModeKeys.TRAIN:
return fusion_spec
dnn_spec = head.create_estimator_spec(
features=features,
mode=mode,
labels=labels,
train_op_fn=dnn_train_op_fn,
logits=dnn_logits)
tree_spec = head.create_estimator_spec(
features=tree_features,
mode=mode,
labels=labels,
train_op_fn=_tree_train_op_fn,
logits=tree_train_logits)
training_hooks = [
trainer_hooks.SwitchTrainOp(dnn_spec.train_op, dnn_steps_to_train,
tree_spec.train_op),
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
finalized_trees,
override_global_step_value)
]
fusion_spec = fusion_spec._replace(training_hooks=training_hooks +
list(fusion_spec.training_hooks))
return fusion_spec
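# The training hooks above implement the two-phase schedule: the estimator's
# own train op is a no-op (see _no_train_op_fn), while trainer_hooks.
# SwitchTrainOp runs the DNN train op until the global step reaches
# dnn_steps_to_train and the tree train op afterwards. Below is a minimal,
# hedged sketch of such a switching hook, written under assumed semantics; the
# actual trainer_hooks.SwitchTrainOp implementation may differ.
from tensorflow.python.training import session_run_hook


class _SwitchTrainOpSketch(session_run_hook.SessionRunHook):
  """Illustrative hook: fetches `first_op` before `switch_step`, then `second_op`."""

  def __init__(self, first_op, switch_step, second_op):
    self._first_op = first_op
    self._switch_step = switch_step
    self._second_op = second_op

  def begin(self):
    # Resolved once the graph is built.
    self._global_step_tensor = training_util.get_global_step()

  def before_run(self, run_context):
    # Read the current step and ask the session to run the matching train op.
    step = run_context.session.run(self._global_step_tensor)
    op = self._first_op if step < self._switch_step else self._second_op
    return session_run_hook.SessionRunArgs(op)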
class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
"""A classifier that uses a combined DNN/GBDT model."""
def __init__(self,
dnn_hidden_units,
dnn_feature_columns,
tree_learner_config,
num_trees,
tree_examples_per_layer,
n_classes=2,
weight_column_name=None,
model_dir=None,
config=None,
label_name=None,
label_keys=None,
feature_engineering_fn=None,
dnn_optimizer="Adagrad",
dnn_activation_fn=nn.relu,
dnn_dropout=None,
dnn_input_layer_partitioner=None,
dnn_input_layer_to_tree=True,
dnn_steps_to_train=10000,
predict_with_tree_only=False,
tree_feature_columns=None,
tree_center_bias=False,
dnn_to_tree_distillation_param=None,
use_core_versions=False,
override_global_step_value=None):
"""Initializes a DNNBoostedTreeCombinedClassifier instance.
Args:
dnn_hidden_units: List of hidden units per layer for DNN.
dnn_feature_columns: An iterable containing all the feature columns
used by the model's DNN.
tree_learner_config: A config for the tree learner.
num_trees: Number of trees to grow model to after training DNN.
tree_examples_per_layer: Number of examples to accumulate before
growing the tree a layer. This value has a big impact on model
quality and should be set equal to the number of examples in
        the training dataset if possible. It can also be a function that computes
the number of examples based on the depth of the layer that's
being built.
n_classes: The number of label classes.
weight_column_name: The name of weight column.
model_dir: Directory for model exports.
config: `RunConfig` of the estimator.
      label_name: String, name of the key in label dict. Can be None if label
        is a tensor (single-headed models).
label_keys: Optional list of strings with size `[n_classes]` defining the
label vocabulary. Only supported for `n_classes` > 2.
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
dnn_optimizer: string, `Optimizer` object, or callable that defines the
optimizer to use for training the DNN. If `None`, will use the Adagrad
optimizer with default learning rate.
dnn_activation_fn: Activation function applied to each layer of the DNN.
If `None`, will use `tf.nn.relu`.
dnn_dropout: When not `None`, the probability to drop out a given
unit in the DNN.
dnn_input_layer_partitioner: Partitioner for input layer of the DNN.
Defaults to `min_max_variable_partitioner` with `min_slice_size`
64 << 20.
dnn_input_layer_to_tree: Whether to provide the DNN's input layer
as a feature to the tree.
      dnn_steps_to_train: Number of steps to train the DNN before switching
        to the GBDT.
predict_with_tree_only: Whether to use only the tree model output as the
final prediction.
tree_feature_columns: An iterable containing all the feature columns
used by the model's boosted trees. If dnn_input_layer_to_tree is
set to True, these features are in addition to dnn_feature_columns.
tree_center_bias: Whether a separate tree should be created for
first fitting the bias.
      dnn_to_tree_distillation_param: A tuple of (float, loss_fn), where the
        float defines the weight of the distillation loss and loss_fn computes
        the distillation loss, taking dnn_logits, tree_logits and a weight
        tensor. If the entire tuple is None, no distillation is applied. If
        only the loss_fn is None, the sigmoid/softmax cross-entropy loss is
        used by default. When distillation is applied,
        `predict_with_tree_only` will be set to True.
      use_core_versions: Whether feature columns and loss are from the core (as
        opposed to contrib) version of tensorflow.
      override_global_step_value: If set, the global step is reset to this
        value after training is done. This is particularly useful for
        hyperparameter tuning, which cannot recognize early stopping due to
        the number of trees. If None, the global step is not overridden.
"""
head = head_lib.multi_class_head(
n_classes=n_classes,
label_name=label_name,
label_keys=label_keys,
weight_column_name=weight_column_name,
enable_centered_bias=False)
def _model_fn(features, labels, mode, config):
return _dnn_tree_combined_model_fn(
features=features,
labels=labels,
mode=mode,
head=head,
dnn_hidden_units=dnn_hidden_units,
dnn_feature_columns=dnn_feature_columns,
tree_learner_config=tree_learner_config,
num_trees=num_trees,
tree_examples_per_layer=tree_examples_per_layer,
config=config,
dnn_optimizer=dnn_optimizer,
dnn_activation_fn=dnn_activation_fn,
dnn_dropout=dnn_dropout,
dnn_input_layer_partitioner=dnn_input_layer_partitioner,
dnn_input_layer_to_tree=dnn_input_layer_to_tree,
dnn_steps_to_train=dnn_steps_to_train,
predict_with_tree_only=predict_with_tree_only,
tree_feature_columns=tree_feature_columns,
tree_center_bias=tree_center_bias,
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
use_core_versions=use_core_versions,
override_global_step_value=override_global_step_value)
super(DNNBoostedTreeCombinedClassifier, self).__init__(
model_fn=_model_fn,
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
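# Example (a hedged sketch, not part of the original library): constructing a
# minimal DNNBoostedTreeCombinedClassifier. The feature name "x" and all
# numeric settings are hypothetical choices for illustration.
def _example_dnn_tree_classifier(model_dir=None):
  from tensorflow.contrib.boosted_trees.proto import learner_pb2
  from tensorflow.contrib.layers.python.layers import feature_column

  learner_config = learner_pb2.LearnerConfig()
  learner_config.num_classes = 2
  learner_config.constraints.max_tree_depth = 3
  return DNNBoostedTreeCombinedClassifier(
      dnn_hidden_units=[32, 16],
      dnn_feature_columns=[feature_column.real_valued_column("x")],
      tree_learner_config=learner_config,
      num_trees=10,
      tree_examples_per_layer=1024,
      n_classes=2,
      model_dir=model_dir,
      # Fit the DNN for 1000 steps before growing trees.
      dnn_steps_to_train=1000)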
class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
"""A regressor that uses a combined DNN/GBDT model."""
def __init__(self,
dnn_hidden_units,
dnn_feature_columns,
tree_learner_config,
num_trees,
tree_examples_per_layer,
weight_column_name=None,
model_dir=None,
config=None,
label_name=None,
label_dimension=1,
feature_engineering_fn=None,
dnn_optimizer="Adagrad",
dnn_activation_fn=nn.relu,
dnn_dropout=None,
dnn_input_layer_partitioner=None,
dnn_input_layer_to_tree=True,
dnn_steps_to_train=10000,
predict_with_tree_only=False,
tree_feature_columns=None,
tree_center_bias=False,
dnn_to_tree_distillation_param=None,
use_core_versions=False,
override_global_step_value=None):
"""Initializes a DNNBoostedTreeCombinedRegressor instance.
Args:
dnn_hidden_units: List of hidden units per layer for DNN.
dnn_feature_columns: An iterable containing all the feature columns
used by the model's DNN.
tree_learner_config: A config for the tree learner.
num_trees: Number of trees to grow model to after training DNN.
tree_examples_per_layer: Number of examples to accumulate before
growing the tree a layer. This value has a big impact on model
quality and should be set equal to the number of examples in
        the training dataset if possible. It can also be a function that computes
the number of examples based on the depth of the layer that's
being built.
weight_column_name: The name of weight column.
model_dir: Directory for model exports.
config: `RunConfig` of the estimator.
      label_name: String, name of the key in label dict. Can be None if label
        is a tensor (single-headed models).
label_dimension: Number of regression labels per example. This is the size
of the last dimension of the labels `Tensor` (typically, this has shape
`[batch_size, label_dimension]`).
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
dnn_optimizer: string, `Optimizer` object, or callable that defines the
optimizer to use for training the DNN. If `None`, will use the Adagrad
optimizer with default learning rate.
dnn_activation_fn: Activation function applied to each layer of the DNN.
If `None`, will use `tf.nn.relu`.
dnn_dropout: When not `None`, the probability to drop out a given
unit in the DNN.
dnn_input_layer_partitioner: Partitioner for input layer of the DNN.
Defaults to `min_max_variable_partitioner` with `min_slice_size`
64 << 20.
dnn_input_layer_to_tree: Whether to provide the DNN's input layer
as a feature to the tree.
      dnn_steps_to_train: Number of steps to train the DNN before switching
        to the GBDT.
predict_with_tree_only: Whether to use only the tree model output as the
final prediction.
tree_feature_columns: An iterable containing all the feature columns
used by the model's boosted trees. If dnn_input_layer_to_tree is
set to True, these features are in addition to dnn_feature_columns.
tree_center_bias: Whether a separate tree should be created for
first fitting the bias.
      dnn_to_tree_distillation_param: A tuple of (float, loss_fn), where the
        float defines the weight of the distillation loss and loss_fn computes
        the distillation loss, taking dnn_logits, tree_logits and a weight
        tensor. If the entire tuple is None, no distillation is applied. If
        only the loss_fn is None, the sigmoid/softmax cross-entropy loss is
        used by default. When distillation is applied,
        `predict_with_tree_only` will be set to True.
      use_core_versions: Whether feature columns and loss are from the core (as
        opposed to contrib) version of tensorflow.
      override_global_step_value: If set, the global step is reset to this
        value after training is done. This is particularly useful for
        hyperparameter tuning, which cannot recognize early stopping due to
        the number of trees. If None, the global step is not overridden.
"""
head = head_lib.regression_head(
label_name=label_name,
label_dimension=label_dimension,
weight_column_name=weight_column_name,
enable_centered_bias=False)
# num_classes needed for GradientBoostedDecisionTreeModel
if label_dimension == 1:
tree_learner_config.num_classes = 2
else:
tree_learner_config.num_classes = label_dimension
def _model_fn(features, labels, mode, config):
return _dnn_tree_combined_model_fn(
features=features,
labels=labels,
mode=mode,
head=head,
dnn_hidden_units=dnn_hidden_units,
dnn_feature_columns=dnn_feature_columns,
tree_learner_config=tree_learner_config,
num_trees=num_trees,
tree_examples_per_layer=tree_examples_per_layer,
config=config,
dnn_optimizer=dnn_optimizer,
dnn_activation_fn=dnn_activation_fn,
dnn_dropout=dnn_dropout,
dnn_input_layer_partitioner=dnn_input_layer_partitioner,
dnn_input_layer_to_tree=dnn_input_layer_to_tree,
dnn_steps_to_train=dnn_steps_to_train,
predict_with_tree_only=predict_with_tree_only,
tree_feature_columns=tree_feature_columns,
tree_center_bias=tree_center_bias,
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
use_core_versions=use_core_versions,
override_global_step_value=override_global_step_value)
super(DNNBoostedTreeCombinedRegressor, self).__init__(
model_fn=_model_fn,
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
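# A hedged sketch of a custom loss_fn for dnn_to_tree_distillation_param: the
# function receives the (stop-gradient) DNN logits, the tree logits and a
# weight tensor (possibly None), and returns a scalar loss. Mean squared error
# is an illustrative choice here, not the contrib default.
def _example_distillation_loss_fn(dnn_logits, tree_logits, weight_tensor):
  from tensorflow.python.ops.losses import losses as core_losses

  weights = 1.0 if weight_tensor is None else weight_tensor
  return core_losses.mean_squared_error(
      labels=dnn_logits, predictions=tree_logits, weights=weights)
# It would be passed as dnn_to_tree_distillation_param=
# (0.5, _example_distillation_loss_fn), where 0.5 is a hypothetical weight for
# the distillation term.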
class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
"""An estimator that uses a combined DNN/GBDT model.
Useful for training with user specified `Head`.
"""
def __init__(self,
dnn_hidden_units,
dnn_feature_columns,
tree_learner_config,
num_trees,
tree_examples_per_layer,
head,
model_dir=None,
config=None,
feature_engineering_fn=None,
dnn_optimizer="Adagrad",
dnn_activation_fn=nn.relu,
dnn_dropout=None,
dnn_input_layer_partitioner=None,
dnn_input_layer_to_tree=True,
dnn_steps_to_train=10000,
predict_with_tree_only=False,
tree_feature_columns=None,
tree_center_bias=False,
dnn_to_tree_distillation_param=None,
use_core_versions=False,
override_global_step_value=None):
"""Initializes a DNNBoostedTreeCombinedEstimator instance.
Args:
dnn_hidden_units: List of hidden units per layer for DNN.
dnn_feature_columns: An iterable containing all the feature columns
used by the model's DNN.
tree_learner_config: A config for the tree learner.
num_trees: Number of trees to grow model to after training DNN.
tree_examples_per_layer: Number of examples to accumulate before
growing the tree a layer. This value has a big impact on model
quality and should be set equal to the number of examples in
        the training dataset if possible. It can also be a function that computes
the number of examples based on the depth of the layer that's
being built.
head: `Head` instance.
model_dir: Directory for model exports.
config: `RunConfig` of the estimator.
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
dnn_optimizer: string, `Optimizer` object, or callable that defines the
optimizer to use for training the DNN. If `None`, will use the Adagrad
optimizer with default learning rate.
dnn_activation_fn: Activation function applied to each layer of the DNN.
If `None`, will use `tf.nn.relu`.
dnn_dropout: When not `None`, the probability to drop out a given
unit in the DNN.
dnn_input_layer_partitioner: Partitioner for input layer of the DNN.
Defaults to `min_max_variable_partitioner` with `min_slice_size`
64 << 20.
dnn_input_layer_to_tree: Whether to provide the DNN's input layer
as a feature to the tree.
      dnn_steps_to_train: Number of steps to train the DNN before switching
        to the GBDT.
predict_with_tree_only: Whether to use only the tree model output as the
final prediction.
tree_feature_columns: An iterable containing all the feature columns
used by the model's boosted trees. If dnn_input_layer_to_tree is
set to True, these features are in addition to dnn_feature_columns.
tree_center_bias: Whether a separate tree should be created for
first fitting the bias.
      dnn_to_tree_distillation_param: A tuple of (float, loss_fn), where the
        float defines the weight of the distillation loss and loss_fn computes
        the distillation loss, taking dnn_logits, tree_logits and a weight
        tensor. If the entire tuple is None, no distillation is applied. If
        only the loss_fn is None, the sigmoid/softmax cross-entropy loss is
        used by default. When distillation is applied,
        `predict_with_tree_only` will be set to True.
      use_core_versions: Whether feature columns and loss are from the core (as
        opposed to contrib) version of tensorflow.
      override_global_step_value: If set, the global step is reset to this
        value after training is done. This is particularly useful for
        hyperparameter tuning, which cannot recognize early stopping due to
        the number of trees. If None, the global step is not overridden.
"""
def _model_fn(features, labels, mode, config):
return _dnn_tree_combined_model_fn(
features=features,
labels=labels,
mode=mode,
head=head,
dnn_hidden_units=dnn_hidden_units,
dnn_feature_columns=dnn_feature_columns,
tree_learner_config=tree_learner_config,
num_trees=num_trees,
tree_examples_per_layer=tree_examples_per_layer,
config=config,
dnn_optimizer=dnn_optimizer,
dnn_activation_fn=dnn_activation_fn,
dnn_dropout=dnn_dropout,
dnn_input_layer_partitioner=dnn_input_layer_partitioner,
dnn_input_layer_to_tree=dnn_input_layer_to_tree,
dnn_steps_to_train=dnn_steps_to_train,
predict_with_tree_only=predict_with_tree_only,
tree_feature_columns=tree_feature_columns,
tree_center_bias=tree_center_bias,
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
use_core_versions=use_core_versions,
override_global_step_value=override_global_step_value)
super(DNNBoostedTreeCombinedEstimator, self).__init__(
model_fn=_model_fn,
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
class CoreDNNBoostedTreeCombinedEstimator(core_estimator.Estimator):
"""Initializes a core version of DNNBoostedTreeCombinedEstimator.
Args:
dnn_hidden_units: List of hidden units per layer for DNN.
dnn_feature_columns: An iterable containing all the feature columns
used by the model's DNN.
tree_learner_config: A config for the tree learner.
num_trees: Number of trees to grow model to after training DNN.
tree_examples_per_layer: Number of examples to accumulate before
growing the tree a layer. This value has a big impact on model
quality and should be set equal to the number of examples in
      the training dataset if possible. It can also be a function that computes
the number of examples based on the depth of the layer that's
being built.
head: `Head` instance.
model_dir: Directory for model exports.
config: `RunConfig` of the estimator.
dnn_optimizer: string, `Optimizer` object, or callable that defines the
optimizer to use for training the DNN. If `None`, will use the Adagrad
optimizer with default learning rate.
dnn_activation_fn: Activation function applied to each layer of the DNN.
If `None`, will use `tf.nn.relu`.
dnn_dropout: When not `None`, the probability to drop out a given
unit in the DNN.
dnn_input_layer_partitioner: Partitioner for input layer of the DNN.
Defaults to `min_max_variable_partitioner` with `min_slice_size`
64 << 20.
dnn_input_layer_to_tree: Whether to provide the DNN's input layer
as a feature to the tree.
    dnn_steps_to_train: Number of steps to train the DNN before switching
      to the GBDT.
predict_with_tree_only: Whether to use only the tree model output as the
final prediction.
tree_feature_columns: An iterable containing all the feature columns
used by the model's boosted trees. If dnn_input_layer_to_tree is
set to True, these features are in addition to dnn_feature_columns.
tree_center_bias: Whether a separate tree should be created for
first fitting the bias.
    dnn_to_tree_distillation_param: A tuple of (float, loss_fn), where the
      float defines the weight of the distillation loss and loss_fn computes
      the distillation loss, taking dnn_logits, tree_logits and a weight
      tensor. If the entire tuple is None, no distillation is applied. If only
      the loss_fn is None, the sigmoid/softmax cross-entropy loss is used by
      default. When distillation is applied, `predict_with_tree_only` will be
      set to True.
"""
def __init__(self,
dnn_hidden_units,
dnn_feature_columns,
tree_learner_config,
num_trees,
tree_examples_per_layer,
head,
model_dir=None,
config=None,
dnn_optimizer="Adagrad",
dnn_activation_fn=nn.relu,
dnn_dropout=None,
dnn_input_layer_partitioner=None,
dnn_input_layer_to_tree=True,
dnn_steps_to_train=10000,
predict_with_tree_only=False,
tree_feature_columns=None,
tree_center_bias=False,
dnn_to_tree_distillation_param=None):
def _model_fn(features, labels, mode, config):
return _dnn_tree_combined_model_fn(
features=features,
labels=labels,
mode=mode,
head=head,
dnn_hidden_units=dnn_hidden_units,
dnn_feature_columns=dnn_feature_columns,
tree_learner_config=tree_learner_config,
num_trees=num_trees,
tree_examples_per_layer=tree_examples_per_layer,
config=config,
dnn_optimizer=dnn_optimizer,
dnn_activation_fn=dnn_activation_fn,
dnn_dropout=dnn_dropout,
dnn_input_layer_partitioner=dnn_input_layer_partitioner,
dnn_input_layer_to_tree=dnn_input_layer_to_tree,
dnn_steps_to_train=dnn_steps_to_train,
predict_with_tree_only=predict_with_tree_only,
tree_feature_columns=tree_feature_columns,
tree_center_bias=tree_center_bias,
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC,
use_core_versions=True,
override_global_step_value=None)
super(CoreDNNBoostedTreeCombinedEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)

View File

@ -1,249 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for combined DNN + GBDT estimators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
from tensorflow.contrib.boosted_trees.estimator_batch import dnn_tree_combined_estimator as estimator
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.python.estimator import exporter
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.export import export
from tensorflow.python.ops import parsing_ops
from tensorflow.python.feature_column import feature_column_lib as core_feature_column
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import googletest
from tensorflow.python.training import checkpoint_utils
def _train_input_fn():
features = {
"x": constant_op.constant([[2.], [1.], [1.]])
}
label = constant_op.constant([[1], [0], [0]], dtype=dtypes.int32)
return features, label
def _eval_input_fn():
features = {
"x": constant_op.constant([[1.], [2.], [2.]])
}
label = constant_op.constant([[0], [1], [1]], dtype=dtypes.int32)
return features, label
class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase):
def testClassifierContract(self):
estimator_test_utils.assert_estimator_contract(
self, estimator.DNNBoostedTreeCombinedClassifier)
def testRegressorContract(self):
estimator_test_utils.assert_estimator_contract(
self, estimator.DNNBoostedTreeCombinedRegressor)
def testEstimatorContract(self):
estimator_test_utils.assert_estimator_contract(
self, estimator.DNNBoostedTreeCombinedEstimator)
def testNoDNNFeatureColumns(self):
learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2
with self.assertRaisesRegexp(
ValueError,
"dnn_feature_columns must be specified"):
classifier = estimator.DNNBoostedTreeCombinedClassifier(
dnn_hidden_units=[1],
dnn_feature_columns=[],
tree_learner_config=learner_config,
num_trees=1,
tree_examples_per_layer=3,
n_classes=2)
classifier.fit(input_fn=_train_input_fn, steps=5)
def testFitAndEvaluateDontThrowException(self):
learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2
learner_config.constraints.max_tree_depth = 1
model_dir = tempfile.mkdtemp()
config = run_config.RunConfig()
classifier = estimator.DNNBoostedTreeCombinedClassifier(
dnn_hidden_units=[1],
dnn_feature_columns=[feature_column.real_valued_column("x")],
tree_learner_config=learner_config,
num_trees=1,
tree_examples_per_layer=3,
n_classes=2,
model_dir=model_dir,
config=config,
dnn_steps_to_train=10,
dnn_input_layer_to_tree=False,
tree_feature_columns=[feature_column.real_valued_column("x")])
classifier.fit(input_fn=_train_input_fn, steps=15)
classifier.evaluate(input_fn=_eval_input_fn, steps=1)
def testFitAndEvaluateWithDistillation(self):
learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2
learner_config.constraints.max_tree_depth = 1
model_dir = tempfile.mkdtemp()
config = run_config.RunConfig()
classifier = estimator.DNNBoostedTreeCombinedClassifier(
dnn_hidden_units=[1],
dnn_feature_columns=[feature_column.real_valued_column("x")],
tree_learner_config=learner_config,
num_trees=1,
tree_examples_per_layer=3,
n_classes=2,
model_dir=model_dir,
config=config,
dnn_steps_to_train=10,
dnn_input_layer_to_tree=False,
tree_feature_columns=[feature_column.real_valued_column("x")],
dnn_to_tree_distillation_param=(1, None))
classifier.fit(input_fn=_train_input_fn, steps=15)
classifier.evaluate(input_fn=_eval_input_fn, steps=1)
class CoreDNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase):
def _assert_checkpoint(self, model_dir, global_step):
reader = checkpoint_utils.load_checkpoint(model_dir)
self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
def testTrainEvaluateInferDoesNotThrowErrorWithNoDnnInput(self):
head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2
learner_config.constraints.max_tree_depth = 3
model_dir = tempfile.mkdtemp()
config = run_config.RunConfig()
est = estimator.CoreDNNBoostedTreeCombinedEstimator(
head=head_fn,
dnn_hidden_units=[1],
dnn_feature_columns=[core_feature_column.numeric_column("x")],
tree_learner_config=learner_config,
num_trees=1,
tree_examples_per_layer=3,
model_dir=model_dir,
config=config,
dnn_steps_to_train=10,
dnn_input_layer_to_tree=False,
tree_feature_columns=[core_feature_column.numeric_column("x")])
# Train for a few steps.
est.train(input_fn=_train_input_fn, steps=1000)
    # 10 steps for the DNN, 3 for one tree of depth 3, plus 1 step after the
    # tree finished.
self._assert_checkpoint(est.model_dir, global_step=14)
res = est.evaluate(input_fn=_eval_input_fn, steps=1)
self.assertLess(0.5, res["auc"])
est.predict(input_fn=_eval_input_fn)
def testTrainEvaluateInferDoesNotThrowErrorWithDnnInput(self):
head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2
learner_config.constraints.max_tree_depth = 3
model_dir = tempfile.mkdtemp()
config = run_config.RunConfig()
est = estimator.CoreDNNBoostedTreeCombinedEstimator(
head=head_fn,
dnn_hidden_units=[1],
dnn_feature_columns=[core_feature_column.numeric_column("x")],
tree_learner_config=learner_config,
num_trees=1,
tree_examples_per_layer=3,
model_dir=model_dir,
config=config,
dnn_steps_to_train=10,
dnn_input_layer_to_tree=True,
tree_feature_columns=[])
# Train for a few steps.
est.train(input_fn=_train_input_fn, steps=1000)
res = est.evaluate(input_fn=_eval_input_fn, steps=1)
self.assertLess(0.5, res["auc"])
est.predict(input_fn=_eval_input_fn)
def testTrainEvaluateWithDnnForInputAndTreeForPredict(self):
head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2
learner_config.constraints.max_tree_depth = 3
model_dir = tempfile.mkdtemp()
config = run_config.RunConfig()
est = estimator.CoreDNNBoostedTreeCombinedEstimator(
head=head_fn,
dnn_hidden_units=[1],
dnn_feature_columns=[core_feature_column.numeric_column("x")],
tree_learner_config=learner_config,
num_trees=1,
tree_examples_per_layer=3,
model_dir=model_dir,
config=config,
dnn_steps_to_train=10,
dnn_input_layer_to_tree=True,
predict_with_tree_only=True,
dnn_to_tree_distillation_param=(0.5, None),
tree_feature_columns=[])
# Train for a few steps.
est.train(input_fn=_train_input_fn, steps=1000)
res = est.evaluate(input_fn=_eval_input_fn, steps=1)
self.assertLess(0.5, res["auc"])
est.predict(input_fn=_eval_input_fn)
serving_input_fn = (
export.build_parsing_serving_input_receiver_fn(
feature_spec={"x": parsing_ops.FixedLenFeature(
[1], dtype=dtypes.float32)}))
base_exporter = exporter.FinalExporter(
name="Servo",
serving_input_receiver_fn=serving_input_fn,
assets_extra=None)
export_path = os.path.join(model_dir, "export")
base_exporter.export(
est,
export_path=export_path,
checkpoint_path=None,
eval_result={},
is_the_final_export=True)
if __name__ == "__main__":
googletest.main()

View File

@ -1,837 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""GTFlow Estimator definition."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from tensorflow.contrib.boosted_trees.estimator_batch import model
from tensorflow.contrib.boosted_trees.python.utils import losses
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.python.estimator.canned import head as core_head_lib
from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.losses import losses as core_losses
from tensorflow.contrib.boosted_trees.estimator_batch import custom_loss_head
from tensorflow.python.ops import array_ops
# ================== Old estimator interface ==================================
# The estimators below were designed for old feature columns and old estimator
# interface. They can be used with new feature columns and losses by setting
# use_core_libs = True.
class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
"""An estimator using gradient boosted decision trees."""
def __init__(self,
learner_config,
examples_per_layer,
n_classes=2,
num_trees=None,
feature_columns=None,
weight_column_name=None,
model_dir=None,
config=None,
label_keys=None,
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
override_global_step_value=None,
num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeClassifier estimator instance.
Args:
learner_config: A config for the learner.
examples_per_layer: Number of examples to accumulate before growing a
layer. It can also be a function that computes the number of examples
based on the depth of the layer that's being built.
n_classes: Number of classes in the classification.
num_trees: An int, number of trees to build.
feature_columns: A list of feature columns.
weight_column_name: Name of the column for weights, or None if not
weighted.
model_dir: Directory for model exports, etc.
config: `RunConfig` object to configure the runtime settings.
label_keys: Optional list of strings with size `[n_classes]` defining the
label vocabulary. Only supported for `n_classes` > 2.
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
logits_modifier_function: A modifier function for the logits.
center_bias: Whether a separate tree should be created for first fitting
the bias.
use_core_libs: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
      output_leaf_index: whether to output leaf indices along with predictions
        during inference. The leaf node indexes are available in the
        predictions dict under the key 'leaf_index'; the value is a rank-2
        Tensor of shape [batch_size, num_trees], with one leaf index per tree.
        For example:
          result_iter = classifier.predict(...)
          for result_dict in result_iter:
            result_dict["leaf_index"]  # One leaf index per tree.
      override_global_step_value: If set, the global step is reset to this
        value after training is done. Use it to reset the global step to a
        number greater than the number of steps used to train the current
        ensemble. The usual pattern is to train a number of trees with a very
        large number of training steps; when training finishes (the requested
        trees were built), this parameter makes it look like that many
        training steps ran. If None, the global step is not overridden.
num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: If learner_config is not valid.
"""
if n_classes > 2:
# For multi-class classification, use our loss implementation that
      # supports second-order derivatives.
def loss_fn(labels, logits, weights=None):
result = losses.per_example_maxent_loss(
labels=labels,
logits=logits,
weights=weights,
num_classes=n_classes)
return math_ops.reduce_mean(result[0])
else:
loss_fn = None
head = head_lib.multi_class_head(
n_classes=n_classes,
weight_column_name=weight_column_name,
enable_centered_bias=False,
loss_fn=loss_fn,
label_keys=label_keys)
if learner_config.num_classes == 0:
learner_config.num_classes = n_classes
elif learner_config.num_classes != n_classes:
raise ValueError("n_classes (%d) doesn't match learner_config (%d)." %
(n_classes, learner_config.num_classes))
super(GradientBoostedDecisionTreeClassifier, self).__init__(
model_fn=model.model_builder,
params={
'head': head,
'feature_columns': feature_columns,
'learner_config': learner_config,
'num_trees': num_trees,
'weight_column_name': weight_column_name,
'examples_per_layer': examples_per_layer,
'center_bias': center_bias,
'logits_modifier_function': logits_modifier_function,
'use_core_libs': use_core_libs,
'output_leaf_index': output_leaf_index,
'override_global_step_value': override_global_step_value,
'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
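# Example (a hedged sketch): a multiclass GBDT classifier. With n_classes > 2
# the constructor above installs the maxent loss, which supports second-order
# derivatives. The feature name and config values are hypothetical.
def _example_gbdt_multiclass_classifier():
  from tensorflow.contrib.boosted_trees.proto import learner_pb2
  from tensorflow.contrib.layers.python.layers import feature_column

  learner_config = learner_pb2.LearnerConfig()
  learner_config.num_classes = 3
  learner_config.constraints.max_tree_depth = 4
  return GradientBoostedDecisionTreeClassifier(
      learner_config=learner_config,
      examples_per_layer=1000,
      n_classes=3,
      num_trees=20,
      feature_columns=[feature_column.real_valued_column("x")])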
class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
"""An estimator using gradient boosted decision trees."""
def __init__(self,
learner_config,
examples_per_layer,
label_dimension=1,
num_trees=None,
feature_columns=None,
label_name=None,
weight_column_name=None,
model_dir=None,
config=None,
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
override_global_step_value=None,
num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRegressor estimator instance.
Args:
learner_config: A config for the learner.
examples_per_layer: Number of examples to accumulate before growing a
layer. It can also be a function that computes the number of examples
based on the depth of the layer that's being built.
label_dimension: Number of regression labels per example. This is the size
of the last dimension of the labels `Tensor` (typically, this has shape
`[batch_size, label_dimension]`).
num_trees: An int, number of trees to build.
feature_columns: A list of feature columns.
      label_name: String, name of the key in label dict. Can be None if label
        is a tensor (single-headed models).
weight_column_name: Name of the column for weights, or None if not
weighted.
model_dir: Directory for model exports, etc.
config: `RunConfig` object to configure the runtime settings.
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
logits_modifier_function: A modifier function for the logits.
center_bias: Whether a separate tree should be created for first fitting
the bias.
use_core_libs: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
      output_leaf_index: whether to output leaf indices along with predictions
        during inference. The leaf node indexes are available in the
        predictions dict under the key 'leaf_index'. For example:
          result_dict = classifier.predict(...)
          for example_prediction_result in result_dict:
            example_prediction_result["leaf_index"]  # One leaf index per tree.
      override_global_step_value: If set, the global step is reset to this
        value after training is done. Use it to reset the global step to a
        number greater than the number of steps used to train the current
        ensemble. The usual pattern is to train a number of trees with a very
        large number of training steps; when training finishes (the requested
        trees were built), this parameter makes it look like that many
        training steps ran. If None, the global step is not overridden.
num_quantiles: Number of quantiles to build for numeric feature values.
"""
head = head_lib.regression_head(
label_name=label_name,
label_dimension=label_dimension,
weight_column_name=weight_column_name,
enable_centered_bias=False)
if label_dimension == 1:
learner_config.num_classes = 2
else:
learner_config.num_classes = label_dimension
super(GradientBoostedDecisionTreeRegressor, self).__init__(
model_fn=model.model_builder,
params={
'head': head,
'feature_columns': feature_columns,
'learner_config': learner_config,
'num_trees': num_trees,
'weight_column_name': weight_column_name,
'examples_per_layer': examples_per_layer,
'logits_modifier_function': logits_modifier_function,
'center_bias': center_bias,
'use_core_libs': use_core_libs,
'output_leaf_index': False,
'override_global_step_value': override_global_step_value,
'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
"""An estimator using gradient boosted decision trees.
Useful for training with user specified `Head`.
"""
def __init__(self,
learner_config,
examples_per_layer,
head,
num_trees=None,
feature_columns=None,
weight_column_name=None,
model_dir=None,
config=None,
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
override_global_step_value=None,
num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeEstimator estimator instance.
Args:
learner_config: A config for the learner.
examples_per_layer: Number of examples to accumulate before growing a
layer. It can also be a function that computes the number of examples
based on the depth of the layer that's being built.
head: `Head` instance.
num_trees: An int, number of trees to build.
feature_columns: A list of feature columns.
weight_column_name: Name of the column for weights, or None if not
weighted.
model_dir: Directory for model exports, etc.
config: `RunConfig` object to configure the runtime settings.
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
logits_modifier_function: A modifier function for the logits.
center_bias: Whether a separate tree should be created for first fitting
the bias.
use_core_libs: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
      output_leaf_index: whether to output leaf indices along with predictions
        during inference. The leaf node indexes are available in the
        predictions dict under the key 'leaf_index'. For example:
          result_dict = classifier.predict(...)
          for example_prediction_result in result_dict:
            example_prediction_result["leaf_index"]  # One leaf index per tree.
      override_global_step_value: If set, the global step is reset to this
        value after training is done. Use it to reset the global step to a
        number greater than the number of steps used to train the current
        ensemble. The usual pattern is to train a number of trees with a very
        large number of training steps; when training finishes (the requested
        trees were built), this parameter makes it look like that many
        training steps ran. If None, the global step is not overridden.
num_quantiles: Number of quantiles to build for numeric feature values.
"""
super(GradientBoostedDecisionTreeEstimator, self).__init__(
model_fn=model.model_builder,
params={
'head': head,
'feature_columns': feature_columns,
'learner_config': learner_config,
'num_trees': num_trees,
'weight_column_name': weight_column_name,
'examples_per_layer': examples_per_layer,
'logits_modifier_function': logits_modifier_function,
'center_bias': center_bias,
'use_core_libs': use_core_libs,
'output_leaf_index': False,
'override_global_step_value': override_global_step_value,
'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
class GradientBoostedDecisionTreeRanker(estimator.Estimator):
"""A ranking estimator using gradient boosted decision trees."""
def __init__(self,
learner_config,
examples_per_layer,
head,
ranking_model_pair_keys,
num_trees=None,
feature_columns=None,
weight_column_name=None,
model_dir=None,
config=None,
label_keys=None,
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=False,
use_core_libs=False,
output_leaf_index=False,
override_global_step_value=None,
num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRanker instance.
    This is an estimator that can be trained on pairwise data and used for
    inference on non-paired data. This is essentially LambdaMART.
Args:
learner_config: A config for the learner.
examples_per_layer: Number of examples to accumulate before growing a
layer. It can also be a function that computes the number of examples
based on the depth of the layer that's being built.
head: `Head` instance.
ranking_model_pair_keys: Keys to distinguish between features for left and
right part of the training pairs for ranking. For example, for an
Example with features "a.f1" and "b.f1", the keys would be ("a", "b").
num_trees: An int, number of trees to build.
feature_columns: A list of feature columns.
weight_column_name: Name of the column for weights, or None if not
weighted.
model_dir: Directory for model exports, etc.
config: `RunConfig` object to configure the runtime settings.
label_keys: Optional list of strings with size `[n_classes]` defining the
label vocabulary. Only supported for `n_classes` > 2.
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
logits_modifier_function: A modifier function for the logits.
center_bias: Whether a separate tree should be created for first fitting
the bias.
use_core_libs: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
      output_leaf_index: whether to output leaf indices along with predictions
        during inference. The leaf node indexes are available in the
        predictions dict under the key 'leaf_index'; the value is a rank-2
        Tensor of shape [batch_size, num_trees], with one leaf index per tree.
        For example:
          result_iter = classifier.predict(...)
          for result_dict in result_iter:
            result_dict["leaf_index"]  # One leaf index per tree.
      override_global_step_value: If set, the global step is reset to this
        value after training is done. Use it to reset the global step to a
        number greater than the number of steps used to train the current
        ensemble. The usual pattern is to train a number of trees with a very
        large number of training steps; when training finishes (the requested
        trees were built), this parameter makes it look like that many
        training steps ran. If None, the global step is not overridden.
num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: If learner_config is not valid.
"""
super(GradientBoostedDecisionTreeRanker, self).__init__(
model_fn=model.ranking_model_builder,
params={
'head': head,
'n_classes': 2,
'feature_columns': feature_columns,
'learner_config': learner_config,
'num_trees': num_trees,
'weight_column_name': weight_column_name,
'examples_per_layer': examples_per_layer,
'center_bias': center_bias,
'logits_modifier_function': logits_modifier_function,
'use_core_libs': use_core_libs,
'output_leaf_index': output_leaf_index,
'ranking_model_pair_keys': ranking_model_pair_keys,
'override_global_step_value': override_global_step_value,
'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
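# A hedged sketch of the pairwise input convention that the ranker's
# ranking_model_pair_keys argument describes: each training Example carries
# both sides of a pair under prefixed feature names ("a." and "b."), and the
# label says whether the left side should rank higher. Feature names and
# values are hypothetical.
def _example_ranker_train_input_fn():
  from tensorflow.python.framework import constant_op
  from tensorflow.python.framework import dtypes

  features = {
      "a.relevance": constant_op.constant([[0.9], [0.1]]),
      "b.relevance": constant_op.constant([[0.2], [0.8]]),
  }
  labels = constant_op.constant([[1], [0]], dtype=dtypes.int32)
  return features, labels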
# When using this estimator, make sure to regularize the Hessian (use at least
# the l2 and min_node_weight settings)!
# TODO(nponomareva): extend to take multiple quantiles in one go.
class GradientBoostedDecisionTreeQuantileRegressor(estimator.Estimator):
"""An estimator that does quantile regression and returns quantile estimates."""
def __init__(self,
learner_config,
examples_per_layer,
quantiles,
label_dimension=1,
num_trees=None,
feature_columns=None,
weight_column_name=None,
model_dir=None,
config=None,
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
override_global_step_value=None,
num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeQuantileRegressor instance.
Args:
learner_config: A config for the learner.
examples_per_layer: Number of examples to accumulate before growing a
layer. It can also be a function that computes the number of examples
based on the depth of the layer that's being built.
quantiles: a list of quantiles for the loss, each between 0 and 1.
label_dimension: Dimension of regression label. This is the size of the
last dimension of the labels `Tensor` (typically, this has shape
        `[batch_size, label_dimension]`). When label_dimension > 1, it is
        recommended to use the multiclass strategy with a diagonal or full
        Hessian.
num_trees: An int, number of trees to build.
feature_columns: A list of feature columns.
weight_column_name: Name of the column for weights, or None if not
weighted.
model_dir: Directory for model exports, etc.
config: `RunConfig` object to configure the runtime settings.
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
logits_modifier_function: A modifier function for the logits.
center_bias: Whether a separate tree should be created for first fitting
the bias.
use_core_libs: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
      output_leaf_index: whether to output leaf indices along with predictions
        during inference. The leaf node indexes are available in the
        predictions dict under the key 'leaf_index'. For example:
          result_dict = classifier.predict(...)
          for example_prediction_result in result_dict:
            example_prediction_result["leaf_index"]  # One leaf index per tree.
      override_global_step_value: If set, the global step is reset to this
        value after training is done. Use it to reset the global step to a
        number greater than the number of steps used to train the current
        ensemble. The usual pattern is to train a number of trees with a very
        large number of training steps; when training finishes (the requested
        trees were built), this parameter makes it look like that many
        training steps ran. If None, the global step is not overridden.
num_quantiles: Number of quantiles to build for numeric feature values.
"""
if len(quantiles) > 1:
raise ValueError('For now, just one quantile per estimator is supported')
def _quantile_regression_head(quantile):
# Use quantile regression.
head = custom_loss_head.CustomLossHead(
loss_fn=functools.partial(
losses.per_example_quantile_regression_loss, quantile=quantile),
link_fn=array_ops.identity,
logit_dimension=label_dimension)
return head
learner_config.num_classes = max(2, label_dimension)
super(GradientBoostedDecisionTreeQuantileRegressor, self).__init__(
model_fn=model.model_builder,
params={
'head': _quantile_regression_head(quantiles[0]),
'feature_columns': feature_columns,
'learner_config': learner_config,
'num_trees': num_trees,
'weight_column_name': weight_column_name,
'examples_per_layer': examples_per_layer,
'logits_modifier_function': logits_modifier_function,
'center_bias': center_bias,
'use_core_libs': use_core_libs,
'output_leaf_index': False,
'override_global_step_value': override_global_step_value,
'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
# ================== New estimator interface ==================================
# The estimators below use new core Estimator interface and must be used with
# new feature columns and heads.
# For multiclass classification, use the following head, since it uses a loss
# that is twice differentiable.
def core_multiclass_head(
n_classes,
weight_column=None,
loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS):
"""Core head for multiclass problems."""
def loss_fn(labels, logits):
result = losses.per_example_maxent_loss(
# Don't pass the weights: head already multiplies by them.
labels=labels, logits=logits, weights=None, num_classes=n_classes)
return result[0]
# pylint:disable=protected-access
head_fn = core_head_lib._multi_class_head_with_softmax_cross_entropy_loss(
n_classes=n_classes,
loss_fn=loss_fn,
loss_reduction=loss_reduction,
weight_column=weight_column)
# pylint:enable=protected-access
return head_fn
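# Example (a hedged sketch): pairing core_multiclass_head with the core
# estimator defined further below. The feature name and config values are
# hypothetical.
def _example_core_multiclass_estimator():
  from tensorflow.contrib.boosted_trees.proto import learner_pb2
  from tensorflow.python.feature_column import feature_column_lib as core_fc

  learner_config = learner_pb2.LearnerConfig()
  learner_config.num_classes = 3
  return CoreGradientBoostedDecisionTreeEstimator(
      learner_config=learner_config,
      examples_per_layer=1000,
      head=core_multiclass_head(n_classes=3),
      num_trees=10,
      feature_columns=[core_fc.numeric_column("x")])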
# For quantile regression, use this head with
# CoreGradientBoostedDecisionTreeEstimator, or use
# CoreGradientBoostedDecisionTreeQuantileRegressor directly.
def core_quantile_regression_head(
quantiles,
label_dimension=1,
weight_column=None,
loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS):
"""Core head for quantile regression problems."""
def loss_fn(labels, logits):
result = losses.per_example_quantile_regression_loss(
labels=labels,
predictions=logits,
# Don't pass the weights: head already multiplies by them.
weights=None,
quantile=quantiles)
return result[0]
# pylint:disable=protected-access
head_fn = core_head_lib._regression_head(
label_dimension=label_dimension,
loss_fn=loss_fn,
loss_reduction=loss_reduction,
weight_column=weight_column)
# pylint:enable=protected-access
return head_fn
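# Example (a hedged sketch): a head that targets the 90th percentile. The
# quantile value is a hypothetical choice; the resulting head can be passed to
# CoreGradientBoostedDecisionTreeEstimator below.
def _example_quantile_head():
  return core_quantile_regression_head(quantiles=0.9, label_dimension=1)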
class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
"""An estimator using gradient boosted decision trees.
Useful for training with user specified `Head`.
"""
def __init__(self,
learner_config,
examples_per_layer,
head,
num_trees=None,
feature_columns=None,
weight_column_name=None,
model_dir=None,
config=None,
label_keys=None,
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=True,
output_leaf_index=False,
num_quantiles=100):
"""Initializes a core version of GradientBoostedDecisionTreeEstimator.
Args:
learner_config: A config for the learner.
examples_per_layer: Number of examples to accumulate before growing a
layer. It can also be a function that computes the number of examples
based on the depth of the layer that's being built.
head: `Head` instance.
num_trees: An int, number of trees to build.
feature_columns: A list of feature columns.
weight_column_name: Name of the column for weights, or None if not
weighted.
model_dir: Directory for model exports, etc.
config: `RunConfig` object to configure the runtime settings.
label_keys: Optional list of strings with size `[n_classes]` defining the
label vocabulary. Only supported for `n_classes` > 2.
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
logits_modifier_function: A modifier function for the logits.
center_bias: Whether a separate tree should be created for first fitting
the bias.
      output_leaf_index: whether to output leaf indices along with predictions
        during inference. The leaf node indexes are available in the
        predictions dict under the key 'leaf_index'. For example:
          result_dict = classifier.predict(...)
          for example_prediction_result in result_dict:
            example_prediction_result["leaf_index"]  # One leaf index per tree.
num_quantiles: Number of quantiles to build for numeric feature values.
"""
def _model_fn(features, labels, mode, config):
return model.model_builder(
features=features,
labels=labels,
mode=mode,
config=config,
params={
'head': head,
'feature_columns': feature_columns,
'learner_config': learner_config,
'num_trees': num_trees,
'weight_column_name': weight_column_name,
'examples_per_layer': examples_per_layer,
'center_bias': center_bias,
'logits_modifier_function': logits_modifier_function,
'use_core_libs': True,
'output_leaf_index': output_leaf_index,
'override_global_step_value': None,
'num_quantiles': num_quantiles,
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
super(CoreGradientBoostedDecisionTreeEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
"""A ranking estimator using gradient boosted decision trees."""
def __init__(self,
learner_config,
examples_per_layer,
head,
ranking_model_pair_keys,
num_trees=None,
feature_columns=None,
weight_column_name=None,
model_dir=None,
config=None,
label_keys=None,
logits_modifier_function=None,
center_bias=False,
output_leaf_index=False,
num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRanker instance.
    This is an estimator that can be trained on pairwise data and used for
    inference on non-paired data. This is essentially LambdaMART.
Args:
learner_config: A config for the learner.
examples_per_layer: Number of examples to accumulate before growing a
layer. It can also be a function that computes the number of examples
based on the depth of the layer that's being built.
head: `Head` instance.
ranking_model_pair_keys: Keys to distinguish between features for left and
right part of the training pairs for ranking. For example, for an
Example with features "a.f1" and "b.f1", the keys would be ("a", "b").
num_trees: An int, number of trees to build.
feature_columns: A list of feature columns.
weight_column_name: Name of the column for weights, or None if not
weighted.
model_dir: Directory for model exports, etc.
config: `RunConfig` object to configure the runtime settings.
label_keys: Optional list of strings with size `[n_classes]` defining the
label vocabulary. Only supported for `n_classes` > 2.
logits_modifier_function: A modifier function for the logits.
center_bias: Whether a separate tree should be created for first fitting
the bias.
      output_leaf_index: whether to output leaf indices along with predictions
        during inference. The leaf node indexes are available in the
        predictions dict under the key 'leaf_index'; the value is a rank-2
        Tensor of shape [batch_size, num_trees], with one leaf index per tree.
        For example:
          result_iter = classifier.predict(...)
          for result_dict in result_iter:
            result_dict["leaf_index"]  # One leaf index per tree.
num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: If learner_config is not valid.
"""
def _model_fn(features, labels, mode, config):
return model.ranking_model_builder(
features=features,
labels=labels,
mode=mode,
config=config,
params={
'head': head,
'n_classes': 2,
'feature_columns': feature_columns,
'learner_config': learner_config,
'num_trees': num_trees,
'weight_column_name': weight_column_name,
'examples_per_layer': examples_per_layer,
'center_bias': center_bias,
'logits_modifier_function': logits_modifier_function,
'use_core_libs': True,
'output_leaf_index': output_leaf_index,
'ranking_model_pair_keys': ranking_model_pair_keys,
'override_global_step_value': None,
'num_quantiles': num_quantiles,
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
super(CoreGradientBoostedDecisionTreeRanker, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
# When using this estimator, make sure to regularize the Hessian (use at least
# the l2 and min_node_weight settings)!
# TODO(nponomareva): extend to take multiple quantiles in one go.
class CoreGradientBoostedDecisionTreeQuantileRegressor(
core_estimator.Estimator):
"""An estimator that does quantile regression and returns quantile estimates."""
def __init__(self,
learner_config,
examples_per_layer,
quantiles,
label_dimension=1,
num_trees=None,
feature_columns=None,
weight_column_name=None,
model_dir=None,
config=None,
label_keys=None,
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=True,
output_leaf_index=False,
num_quantiles=100):
"""Initializes a core version of GradientBoostedDecisionTreeEstimator.
Args:
learner_config: A config for the learner.
examples_per_layer: Number of examples to accumulate before growing a
layer. It can also be a function that computes the number of examples
based on the depth of the layer that's being built.
quantiles: a list of quantiles for the loss, each between 0 and 1.
label_dimension: Dimension of regression label. This is the size of the
last dimension of the labels `Tensor` (typically, this has shape
`[batch_size, label_dimension]`). When label_dimension > 1, it is
recommended to use the diagonal Hessian or full Hessian multiclass strategy.
num_trees: An int, number of trees to build.
feature_columns: A list of feature columns.
weight_column_name: Name of the column for weights, or None if not
weighted.
model_dir: Directory for model exports, etc.
config: `RunConfig` object to configure the runtime settings.
label_keys: Optional list of strings with size `[n_classes]` defining the
label vocabulary. Only supported for `n_classes` > 2.
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
logits_modifier_function: A modifier function for the logits.
center_bias: Whether a separate tree should be created for first fitting
the bias.
output_leaf_index: whether to output leaf indices along with predictions
during inference. The leaf node indices are available in the predictions
dict under the key 'leaf_index'. For example:
  result_iter = classifier.predict(...)
  for example_prediction_result in result_iter:
    # example_prediction_result["leaf_index"] contains one leaf index
    # per tree.
num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: If more than one quantile is provided.
"""
if len(quantiles) > 1:
raise ValueError('For now, just one quantile per estimator is supported')
def _model_fn(features, labels, mode, config):
return model.model_builder(
features=features,
labels=labels,
mode=mode,
config=config,
params={
'head':
core_quantile_regression_head(
quantiles[0],
label_dimension=label_dimension,
weight_column=weight_column_name),
'feature_columns':
feature_columns,
'learner_config':
learner_config,
'num_trees':
num_trees,
'weight_column_name':
weight_column_name,
'examples_per_layer':
examples_per_layer,
'center_bias':
center_bias,
'logits_modifier_function':
logits_modifier_function,
'use_core_libs':
True,
'output_leaf_index':
output_leaf_index,
'override_global_step_value':
None,
'num_quantiles':
num_quantiles,
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
super(CoreGradientBoostedDecisionTreeQuantileRegressor, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
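# A minimal usage sketch for the quantile regressor above (hedged: the
# input_fns are hypothetical and not defined in this file):
#
#   learner_config = learner_pb2.LearnerConfig()
#   learner_config.constraints.max_tree_depth = 4
#   est = CoreGradientBoostedDecisionTreeQuantileRegressor(
#       learner_config=learner_config,
#       examples_per_layer=1000,
#       quantiles=[0.95],  # exactly one quantile is supported for now
#       num_trees=10)
#   est.train(input_fn=train_input_fn)
#   predictions = est.predict(input_fn=predict_input_fn)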

View File

@ -1,74 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for converting between core and contrib feature columns."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.contrib.learn.python.learn.estimators import model_fn as contrib_model_fn_lib
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.export import export_output
_CORE_MODE_TO_CONTRIB_MODE_ = {
model_fn_lib.ModeKeys.TRAIN: contrib_model_fn_lib.ModeKeys.TRAIN,
model_fn_lib.ModeKeys.EVAL: contrib_model_fn_lib.ModeKeys.EVAL,
model_fn_lib.ModeKeys.PREDICT: contrib_model_fn_lib.ModeKeys.INFER
}
def _core_mode_to_contrib_mode(mode):
return _CORE_MODE_TO_CONTRIB_MODE_[mode]
def _export_outputs_to_output_alternatives(export_outputs):
"""Converts EstimatorSpec.export_outputs to output_alternatives.
Args:
export_outputs: export_outputs created by create_estimator_spec.
Returns:
The converted output_alternatives dict, or None if export_outputs is None.
"""
output = {}
if export_outputs is not None:
for key, value in export_outputs.items():
if isinstance(value, export_output.ClassificationOutput):
exported_predictions = {
prediction_key.PredictionKey.SCORES: value.scores,
prediction_key.PredictionKey.CLASSES: value.classes
}
output[key] = (constants.ProblemType.CLASSIFICATION,
exported_predictions)
return output
return None
def estimator_spec_to_model_fn_ops(estimator_spec, export_alternatives=False):
if export_alternatives:
alternatives = _export_outputs_to_output_alternatives(
estimator_spec.export_outputs)
else:
alternatives = []
return model_fn.ModelFnOps(
mode=_core_mode_to_contrib_mode(estimator_spec.mode),
predictions=estimator_spec.predictions,
loss=estimator_spec.loss,
train_op=estimator_spec.train_op,
eval_metric_ops=estimator_spec.eval_metric_ops,
output_alternatives=alternatives)
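# A sketch of the typical call site (mirroring how the model builder below
# uses this helper; `head`, `features`, `mode`, `labels`, `_train_op_fn` and
# `logits` are placeholders from that context):
#
#   estimator_spec = head.create_estimator_spec(
#       features=features, mode=mode, labels=labels,
#       train_op_fn=_train_op_fn, logits=logits)
#   model_fn_ops = estimator_spec_to_model_fn_ops(estimator_spec)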

View File

@ -1,440 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""GTFlow Model definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
from tensorflow.contrib import learn
from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils
from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks
from tensorflow.contrib.boosted_trees.python.ops import model_ops
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training_util
from google.protobuf import text_format
from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
class ModelBuilderOutputType(object):
MODEL_FN_OPS = 0
ESTIMATOR_SPEC = 1
def model_builder(features,
labels,
mode,
params,
config,
output_type=ModelBuilderOutputType.MODEL_FN_OPS):
"""Multi-machine batch gradient descent tree model.
Args:
features: `Tensor` or `dict` of `Tensor` objects.
labels: Labels used to train on.
mode: Mode we are in. (TRAIN/EVAL/INFER)
params: A dict of hyperparameters.
The following hyperparameters are expected:
* head: A `Head` instance.
* learner_config: A config for the learner.
* feature_columns: An iterable containing all the feature columns used by
the model.
* examples_per_layer: Number of examples to accumulate before growing a
layer. It can also be a function that computes the number of examples
based on the depth of the layer that's being built.
* weight_column_name: The name of weight column.
* center_bias: Whether a separate tree should be created for first fitting
the bias.
* override_global_step_value: If set, the global step value is reset to
this value after training is done. This is particularly useful for
hyperparameter tuning, which cannot recognize early stopping due to the
number of trees. If None, no override of global step will happen.
config: `RunConfig` of the estimator.
output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
(new interface).
Returns:
A `ModelFnOps` or `EstimatorSpec` object, depending on `output_type`.
Raises:
ValueError: if inputs are not valid.
"""
head = params["head"]
learner_config = params["learner_config"]
examples_per_layer = params["examples_per_layer"]
feature_columns = params["feature_columns"]
weight_column_name = params["weight_column_name"]
num_trees = params["num_trees"]
use_core_libs = params["use_core_libs"]
logits_modifier_function = params["logits_modifier_function"]
output_leaf_index = params["output_leaf_index"]
override_global_step_value = params.get("override_global_step_value", None)
num_quantiles = params["num_quantiles"]
if features is None:
raise ValueError("At least one feature must be specified.")
if config is None:
raise ValueError("Missing estimator RunConfig.")
if config.session_config is not None:
session_config = config.session_config
session_config.allow_soft_placement = True
else:
session_config = config_pb2.ConfigProto(allow_soft_placement=True)
config = config.replace(session_config=session_config)
center_bias = params["center_bias"]
if isinstance(features, ops.Tensor):
features = {features.name: features}
# Make a shallow copy of features to ensure downstream usage
# is unaffected by modifications in the model function.
training_features = copy.copy(features)
training_features.pop(weight_column_name, None)
global_step = training_util.get_global_step()
initial_ensemble = ""
if learner_config.each_tree_start.nodes:
if learner_config.each_tree_start_num_layers <= 0:
raise ValueError("You must provide each_tree_start_num_layers.")
num_layers = learner_config.each_tree_start_num_layers
initial_ensemble = """
trees { %s }
tree_weights: 0.1
tree_metadata {
num_tree_weight_updates: 1
num_layers_grown: %d
is_finalized: false
}
""" % (text_format.MessageToString(
learner_config.each_tree_start), num_layers)
tree_ensemble_proto = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge(initial_ensemble, tree_ensemble_proto)
initial_ensemble = tree_ensemble_proto.SerializeToString()
with ops.device(global_step.device):
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0,
tree_ensemble_config=initial_ensemble, # Initialize the ensemble.
name="ensemble_model")
# Create GBDT model.
gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
is_chief=config.is_chief,
num_ps_replicas=config.num_ps_replicas,
ensemble_handle=ensemble_handle,
center_bias=center_bias,
examples_per_layer=examples_per_layer,
learner_config=learner_config,
feature_columns=feature_columns,
logits_dimension=head.logits_dimension,
features=training_features,
use_core_columns=use_core_libs,
output_leaf_index=output_leaf_index,
num_quantiles=num_quantiles)
with ops.name_scope("gbdt", "gbdt_optimizer"):
predictions_dict = gbdt_model.predict(mode)
logits = predictions_dict["predictions"]
if logits_modifier_function:
logits = logits_modifier_function(logits, features, mode)
def _train_op_fn(loss):
"""Returns the op to optimize the loss."""
update_op = gbdt_model.train(loss, predictions_dict, labels)
with ops.control_dependencies(
[update_op]), (ops.colocate_with(global_step)):
update_op = state_ops.assign_add(global_step, 1).op
return update_op
create_estimator_spec_op = getattr(head, "create_estimator_spec", None)
training_hooks = []
if num_trees:
if center_bias:
num_trees += 1
finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
training_hooks.append(
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
finalized_trees,
override_global_step_value))
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
if use_core_libs and callable(create_estimator_spec_op):
model_fn_ops = head.create_estimator_spec(
features=features,
mode=mode,
labels=labels,
train_op_fn=_train_op_fn,
logits=logits)
model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(
model_fn_ops)
else:
model_fn_ops = head.create_model_fn_ops(
features=features,
mode=mode,
labels=labels,
train_op_fn=_train_op_fn,
logits=logits)
if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
gbdt_batch.LEAF_INDEX]
model_fn_ops.training_hooks.extend(training_hooks)
return model_fn_ops
elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC:
assert callable(create_estimator_spec_op)
estimator_spec = head.create_estimator_spec(
features=features,
mode=mode,
labels=labels,
train_op_fn=_train_op_fn,
logits=logits)
if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
estimator_spec.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
gbdt_batch.LEAF_INDEX]
estimator_spec = estimator_spec._replace(
training_hooks=training_hooks + list(estimator_spec.training_hooks))
return estimator_spec
raise ValueError("Unknown output type: %s" % output_type)
def ranking_model_builder(features,
labels,
mode,
params,
config,
output_type=ModelBuilderOutputType.MODEL_FN_OPS):
"""Multi-machine batch gradient descent tree model for ranking.
Args:
features: `Tensor` or `dict` of `Tensor` objects.
labels: Labels used to train on.
mode: Mode we are in. (TRAIN/EVAL/INFER)
params: A dict of hyperparameters.
The following hyperparameters are expected:
* head: A `Head` instance.
* learner_config: A config for the learner.
* feature_columns: An iterable containing all the feature columns used by
the model.
* examples_per_layer: Number of examples to accumulate before growing a
layer. It can also be a function that computes the number of examples
based on the depth of the layer that's being built.
* weight_column_name: The name of weight column.
* center_bias: Whether a separate tree should be created for first fitting
the bias.
* ranking_model_pair_keys (Optional): Keys to distinguish between features
for left and right part of the training pairs for ranking. For example,
for an Example with features "a.f1" and "b.f1", the keys would be
("a", "b").
* override_global_step_value: If set, the global step value is reset to
this value after training is done. This is particularly useful for
hyperparameter tuning, which cannot recognize early stopping due to the
number of trees. If None, no override of global step will happen.
config: `RunConfig` of the estimator.
output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
(new interface).
Returns:
A `ModelFnOps` or `EstimatorSpec` object, depending on `output_type`.
Raises:
ValueError: if inputs are not valid.
"""
head = params["head"]
learner_config = params["learner_config"]
examples_per_layer = params["examples_per_layer"]
feature_columns = params["feature_columns"]
weight_column_name = params["weight_column_name"]
num_trees = params["num_trees"]
use_core_libs = params["use_core_libs"]
logits_modifier_function = params["logits_modifier_function"]
output_leaf_index = params["output_leaf_index"]
ranking_model_pair_keys = params["ranking_model_pair_keys"]
override_global_step_value = params.get("override_global_step_value", None)
num_quantiles = params["num_quantiles"]
if features is None:
raise ValueError("At least one feature must be specified.")
if config is None:
raise ValueError("Missing estimator RunConfig.")
center_bias = params["center_bias"]
if isinstance(features, ops.Tensor):
features = {features.name: features}
# Make a shallow copy of features to ensure downstream usage
# is unaffected by modifications in the model function.
training_features = copy.copy(features)
training_features.pop(weight_column_name, None)
global_step = training_util.get_global_step()
with ops.device(global_step.device):
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0,
tree_ensemble_config="", # Initialize an empty ensemble.
name="ensemble_model")
# Extract the features.
if mode == learn.ModeKeys.TRAIN or mode == learn.ModeKeys.EVAL:
# For ranking pairwise training, we extract two sets of features.
if len(ranking_model_pair_keys) != 2:
raise ValueError("You must provide keys for ranking.")
left_pair_key = ranking_model_pair_keys[0]
right_pair_key = ranking_model_pair_keys[1]
if left_pair_key is None or right_pair_key is None:
raise ValueError("Both pair keys should be provided for ranking.")
features_1 = {}
features_2 = {}
for name in training_features:
feature = training_features[name]
new_name = name[2:]
if name.startswith(left_pair_key + "."):
features_1[new_name] = feature
else:
assert name.startswith(right_pair_key + ".")
features_2[new_name] = feature
main_features = features_1
supplementary_features = features_2
else:
# For non-ranking or inference ranking, we have only 1 set of features.
main_features = training_features
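# Illustration of the split above (hedged, with hypothetical tensors): with
# pair keys ("a", "b") and training features {"a.f1": t1, "b.f1": t2},
# main_features is {"f1": t1} and supplementary_features is {"f1": t2}.
# Note that new_name = name[2:] assumes each pair key is a single character
# followed by ".".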
# Create GBDT model.
gbdt_model_main = gbdt_batch.GradientBoostedDecisionTreeModel(
is_chief=config.is_chief,
num_ps_replicas=config.num_ps_replicas,
ensemble_handle=ensemble_handle,
center_bias=center_bias,
examples_per_layer=examples_per_layer,
learner_config=learner_config,
feature_columns=feature_columns,
logits_dimension=head.logits_dimension,
features=main_features,
use_core_columns=use_core_libs,
output_leaf_index=output_leaf_index,
num_quantiles=num_quantiles)
with ops.name_scope("gbdt", "gbdt_optimizer"):
# Logits for inference.
if mode == learn.ModeKeys.INFER:
predictions_dict = gbdt_model_main.predict(mode)
logits = predictions_dict[gbdt_batch.PREDICTIONS]
if logits_modifier_function:
logits = logits_modifier_function(logits, features, mode)
else:
gbdt_model_supplementary = gbdt_batch.GradientBoostedDecisionTreeModel(
is_chief=config.is_chief,
num_ps_replicas=config.num_ps_replicas,
ensemble_handle=ensemble_handle,
center_bias=center_bias,
examples_per_layer=examples_per_layer,
learner_config=learner_config,
feature_columns=feature_columns,
logits_dimension=head.logits_dimension,
features=supplementary_features,
use_core_columns=use_core_libs,
output_leaf_index=output_leaf_index)
# Logits for train and eval.
if not supplementary_features:
raise ValueError("Features for ranking must be specified.")
predictions_dict_1 = gbdt_model_main.predict(mode)
predictions_1 = predictions_dict_1[gbdt_batch.PREDICTIONS]
predictions_dict_2 = gbdt_model_supplementary.predict(mode)
predictions_2 = predictions_dict_2[gbdt_batch.PREDICTIONS]
logits = predictions_1 - predictions_2
if logits_modifier_function:
logits = logits_modifier_function(logits, features, mode)
predictions_dict = predictions_dict_1
predictions_dict[gbdt_batch.PREDICTIONS] = logits
def _train_op_fn(loss):
"""Returns the op to optimize the loss."""
update_op = gbdt_model_main.train(loss, predictions_dict, labels)
with ops.control_dependencies(
[update_op]), (ops.colocate_with(global_step)):
update_op = state_ops.assign_add(global_step, 1).op
return update_op
create_estimator_spec_op = getattr(head, "create_estimator_spec", None)
training_hooks = []
if num_trees:
if center_bias:
num_trees += 1
finalized_trees, attempted_trees = (
gbdt_model_main.get_number_of_trees_tensor())
training_hooks.append(
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
finalized_trees,
override_global_step_value))
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
if use_core_libs and callable(create_estimator_spec_op):
model_fn_ops = head.create_estimator_spec(
features=features,
mode=mode,
labels=labels,
train_op_fn=_train_op_fn,
logits=logits)
model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(
model_fn_ops)
else:
model_fn_ops = head.create_model_fn_ops(
features=features,
mode=mode,
labels=labels,
train_op_fn=_train_op_fn,
logits=logits)
if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
gbdt_batch.LEAF_INDEX]
model_fn_ops.training_hooks.extend(training_hooks)
return model_fn_ops
elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC:
assert callable(create_estimator_spec_op)
estimator_spec = head.create_estimator_spec(
features=features,
mode=mode,
labels=labels,
train_op_fn=_train_op_fn,
logits=logits)
estimator_spec = estimator_spec._replace(
training_hooks=training_hooks + list(estimator_spec.training_hooks))
return estimator_spec
raise ValueError("Unknown output type: %s" % output_type)

View File

@ -1,230 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks for use with GTFlow Estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensorflow.contrib.learn.python.learn import session_run_hook
from tensorflow.contrib.learn.python.learn.session_run_hook import SessionRunArgs
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
from tensorflow.python.training.summary_io import SummaryWriterCache
class FeatureImportanceSummarySaver(session_run_hook.SessionRunHook):
"""Hook to save feature importance summaries."""
def __init__(self, model_dir, every_n_steps=1):
"""Create a FeatureImportanceSummarySaver Hook.
This hook creates scalar summaries representing feature importance
for each feature column during training.
Args:
model_dir: model base output directory.
every_n_steps: frequency, in number of steps, for logging summaries.
Raises:
ValueError: If one of the arguments is invalid.
"""
if model_dir is None:
raise ValueError("model dir must be specified.")
self._model_dir = model_dir
self._every_n_steps = every_n_steps
self._last_triggered_step = None
def begin(self):
self._global_step_tensor = training_util.get_global_step()
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use FeatureImportanceSummarySaver.")
graph = ops.get_default_graph()
self._feature_names_tensor = graph.get_tensor_by_name(
"gbdt/feature_names:0")
self._feature_usage_counts_tensor = graph.get_tensor_by_name(
"gbdt/feature_usage_counts:0")
self._feature_gains_tensor = graph.get_tensor_by_name(
"gbdt/feature_gains:0")
def before_run(self, run_context):
del run_context # Unused by feature importance summary saver hook.
requests = {
"global_step": self._global_step_tensor,
"feature_names": self._feature_names_tensor,
"feature_usage_counts": self._feature_usage_counts_tensor,
"feature_gains": self._feature_gains_tensor
}
return SessionRunArgs(requests)
def after_run(self, run_context, run_values):
del run_context # Unused by feature importance summary saver hook.
# Read result tensors.
global_step = run_values.results["global_step"]
feature_names = run_values.results["feature_names"]
feature_usage_counts = run_values.results["feature_usage_counts"]
feature_gains = run_values.results["feature_gains"]
# Ensure summaries are logged at desired frequency
if (self._last_triggered_step is not None and
global_step < self._last_triggered_step + self._every_n_steps):
return
# Validate tensors.
if (len(feature_names) != len(feature_usage_counts) or
len(feature_names) != len(feature_gains)):
raise RuntimeError(
"Feature names and importance measures have inconsistent lengths.")
# Compute total usage.
total_usage_count = 0.0
for usage_count in feature_usage_counts:
total_usage_count += usage_count
usage_count_norm = 1.0 / total_usage_count if total_usage_count else 1.0
# Compute total gain.
total_gain = 0.0
for gain in feature_gains:
total_gain += gain
gain_norm = 1.0 / total_gain if total_gain else 1.0
# Output summary for each feature.
self._last_triggered_step = global_step
for (name, usage_count, gain) in zip(feature_names, feature_usage_counts,
feature_gains):
output_dir = os.path.join(self._model_dir, name.decode("utf-8"))
summary_writer = SummaryWriterCache.get(output_dir)
usage_count_summary = Summary(value=[
Summary.Value(
tag="feature_importance/usage_counts", simple_value=usage_count)
])
usage_fraction_summary = Summary(value=[
Summary.Value(
tag="feature_importance/usage_fraction",
simple_value=usage_count * usage_count_norm)
])
summary_writer.add_summary(usage_count_summary, global_step)
summary_writer.add_summary(usage_fraction_summary, global_step)
gains_summary = Summary(value=[
Summary.Value(tag="feature_importance/gains", simple_value=gain)
])
gains_fraction_summary = Summary(value=[
Summary.Value(
tag="feature_importance/gains_fraction",
simple_value=gain * gain_norm)
])
summary_writer.add_summary(gains_summary, global_step)
summary_writer.add_summary(gains_fraction_summary, global_step)
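# A usage sketch (hedged; the estimator and input_fn are hypothetical):
# attach the hook as a monitor so per-feature importance summaries are
# written during training:
#
#   hook = FeatureImportanceSummarySaver(model_dir="/tmp/model",
#                                        every_n_steps=100)
#   estimator.fit(input_fn=train_input_fn, steps=1000, monitors=[hook])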
class FeedFnHook(session_run_hook.SessionRunHook):
"""Runs feed_fn and sets the feed_dict accordingly."""
def __init__(self, feed_fn):
self.feed_fn = feed_fn
def before_run(self, run_context):
del run_context # unused by FeedFnHook.
return session_run_hook.SessionRunArgs(fetches=None, feed_dict=self.feed_fn())
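# A sketch of FeedFnHook usage (hedged; the placeholder and batch iterator
# are hypothetical). feed_fn is called before each step to produce the
# feed_dict:
#
#   hook = FeedFnHook(
#       feed_fn=lambda: {images_placeholder: next(batch_iterator)})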
class StopAfterNTrees(session_run_hook.SessionRunHook):
"""Stop training after building N full trees."""
def __init__(self, n, num_attempted_trees_tensor, num_finalized_trees_tensor,
override_global_step_value=None):
self._num_trees = n
# num_attempted_trees_tensor and num_finalized_trees_tensor are both
# tensors.
self._num_attempted_trees_tensor = num_attempted_trees_tensor
self._num_finalized_trees_tensor = num_finalized_trees_tensor
self._override_global_step_value = override_global_step_value
def begin(self):
self._global_step_tensor = training_util.get_global_step()
if self._global_step_tensor is None:
raise RuntimeError("Global step should be created.")
if self._override_global_step_value is not None:
self._override_global_step_op = state_ops.assign(
self._global_step_tensor, self._override_global_step_value)
def before_run(self, run_context):
del run_context # unused by StopAfterNTrees.
return session_run_hook.SessionRunArgs({
"num_attempted_trees": self._num_attempted_trees_tensor,
"num_finalized_trees": self._num_finalized_trees_tensor,
})
def after_run(self, run_context, run_values):
num_attempted_trees = run_values.results["num_attempted_trees"]
num_finalized_trees = run_values.results["num_finalized_trees"]
assert num_attempted_trees is not None
assert num_finalized_trees is not None
# Stop when the required number of finalized trees is reached, or when we
# try enough times to build a tree but keep failing.
if (num_finalized_trees >= self._num_trees or
num_attempted_trees > 2 * self._num_trees):
logging.info("Requesting stop since we have reached %d trees.",
num_finalized_trees)
if self._override_global_step_value is not None:
logging.info("Overriding global steps value.")
run_context.session.run(self._override_global_step_op)
run_context.request_stop()
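# A usage sketch, mirroring how the model builder above wires this hook up;
# the two tensors come from the GBDT model, and training stops once
# `num_trees` trees are finalized (or after 2 * num_trees attempts):
#
#   finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
#   hook = StopAfterNTrees(num_trees, attempted_trees, finalized_trees)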
class SwitchTrainOp(session_run_hook.SessionRunHook):
"""Hook that switches the train op after specified number of steps.
Hook that replaces the train op depending on the number of steps of training
that have taken place. The first_train_op is used till train_steps steps
are reached. Thereafter the second_train_op is used.
"""
def __init__(self, first_train_op, train_steps, second_train_op):
"""Initializes a `SwitchTrainOp`."""
self._first_train_op = first_train_op
self._second_train_op = second_train_op
self._train_steps = train_steps
def _get_train_op_for_global_step(self, current_step):
"""Gets train_op for current global step."""
if current_step < self._train_steps:
return self._first_train_op
return self._second_train_op
def begin(self):
self._global_step_tensor = training_util.get_global_step()
self._current_train_op = control_flow_ops.no_op()
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use SwitchTrainOp.")
def before_run(self, run_context): # pylint: disable=unused-argument
return session_run_hook.SessionRunArgs(
{"global_step": self._global_step_tensor,
"train_op": self._current_train_op})
def after_run(self, run_context, run_values):
self._current_train_op = self._get_train_op_for_global_step(
run_values.results["global_step"])

View File

@ -1,76 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for trainer hooks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import monitored_session
class FeatureImportanceSummarySaverTest(test_util.TensorFlowTestCase):
def test_invalid_input(self):
with self.assertRaises(ValueError):
trainer_hooks.FeatureImportanceSummarySaver(model_dir=None)
def test_invalid_graph(self):
# Create inputs.
model_dir = tempfile.mkdtemp()
hook = trainer_hooks.FeatureImportanceSummarySaver(model_dir)
with ops.Graph().as_default():
# Begin won't be able to find the required tensors in the graph.
_ = variables.get_or_create_global_step()
with self.assertRaises(KeyError):
hook.begin()
def test_run(self):
# Create inputs.
model_dir = tempfile.mkdtemp()
hook = trainer_hooks.FeatureImportanceSummarySaver(model_dir)
with ops.Graph().as_default(), tf_session.Session() as sess:
global_step = variables.get_or_create_global_step()
with ops.name_scope("gbdt"):
constant_op.constant(["featA", "featB"], name="feature_names")
constant_op.constant([0, 2], name="feature_usage_counts")
constant_op.constant([0, 0.8], name="feature_gains")
# Begin finds tensors in the graph.
hook.begin()
sess.run(tf_variables.global_variables_initializer())
# Run hook in a monitored session.
train_op = state_ops.assign_add(global_step, 1)
mon_sess = monitored_session._HookedSession(sess, [hook])
mon_sess.run(train_op)
hook.end(sess)
# Ensure output summary dirs are created.
self.assertTrue(os.path.exists(os.path.join(model_dir, "featA")))
self.assertTrue(os.path.exists(os.path.join(model_dir, "featB")))
if __name__ == "__main__":
googletest.main()

View File

@ -1,169 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Demonstrates multiclass MNIST TF Boosted trees example.
This example demonstrates how to run experiments with TF Boosted Trees on
a binary dataset. We use digits 4 and 9 from the original MNIST dataset.
Example Usage:
python tensorflow/contrib/boosted_trees/examples/binary_mnist.py \
--output_dir="/tmp/binary_mnist" --depth=4 --learning_rate=0.3 \
--batch_size=10761 --examples_per_layer=10761 --eval_batch_size=1030 \
--num_eval_steps=1 --num_trees=10 --l2=1 --vmodule=training_ops=1
When training is done, accuracy on eval data is reported. Point tensorboard
to the directory for the run to see how the training progresses:
tensorboard --logdir=/tmp/binary_mnist
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import numpy as np
import tensorflow as tf
from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.learn import learn_runner
def get_input_fn(data,
batch_size,
capacity=10000,
min_after_dequeue=3000):
"""Input function over MNIST data."""
# Keep only 4 and 9 digits.
ids = np.where((data.labels == 4) | (data.labels == 9))
images = data.images[ids]
labels = data.labels[ids]
# Make digit 4 label 1 and digit 9 label 0.
labels = labels == 4
def _input_fn():
"""Prepare features and labels."""
images_batch, labels_batch = tf.train.shuffle_batch(
tensors=[images,
labels.astype(np.int32)],
batch_size=batch_size,
capacity=capacity,
min_after_dequeue=min_after_dequeue,
enqueue_many=True,
num_threads=4)
features_map = {"images": images_batch}
return features_map, labels_batch
return _input_fn
# Main config - creates a TF Boosted Trees Estimator based on flags.
def _get_tfbt(output_dir):
"""Configures TF Boosted Trees estimator based on flags."""
learner_config = learner_pb2.LearnerConfig()
learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate
learner_config.regularization.l1 = 0.0
learner_config.regularization.l2 = FLAGS.l2 / FLAGS.examples_per_layer
learner_config.constraints.max_tree_depth = FLAGS.depth
growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER
learner_config.growing_mode = growing_mode
run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300)
# Create a TF Boosted trees estimator that can take in custom loss.
estimator = GradientBoostedDecisionTreeClassifier(
learner_config=learner_config,
examples_per_layer=FLAGS.examples_per_layer,
model_dir=output_dir,
num_trees=FLAGS.num_trees,
center_bias=False,
config=run_config)
return estimator
def _make_experiment_fn(output_dir):
"""Creates experiment for gradient boosted decision trees."""
data = tf.contrib.learn.datasets.mnist.load_mnist()
train_input_fn = get_input_fn(data.train, FLAGS.batch_size)
eval_input_fn = get_input_fn(data.validation, FLAGS.eval_batch_size)
return tf.contrib.learn.Experiment(
estimator=_get_tfbt(output_dir),
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
train_steps=None,
eval_steps=FLAGS.num_eval_steps,
eval_metrics=None)
def main(unused_argv):
learn_runner.run(
experiment_fn=_make_experiment_fn,
output_dir=FLAGS.output_dir,
schedule="train_and_evaluate")
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
parser = argparse.ArgumentParser()
# Define the list of flags that users can change.
parser.add_argument(
"--output_dir",
type=str,
required=True,
help="Choose the dir for the output.")
parser.add_argument(
"--batch_size",
type=int,
default=1000,
help="The batch size for reading data.")
parser.add_argument(
"--eval_batch_size",
type=int,
default=1000,
help="Size of the batch for eval.")
parser.add_argument(
"--num_eval_steps",
type=int,
default=1,
help="The number of steps to run evaluation for.")
# Flags for gradient boosted trees config.
parser.add_argument(
"--depth", type=int, default=4, help="Maximum depth of weak learners.")
parser.add_argument(
"--l2", type=float, default=1.0, help="l2 regularization per batch.")
parser.add_argument(
"--learning_rate",
type=float,
default=0.1,
help="Learning rate (shrinkage weight) with which each new tree is added."
)
parser.add_argument(
"--examples_per_layer",
type=int,
default=1000,
help="Number of examples to accumulate stats for per layer.")
parser.add_argument(
"--num_trees",
type=int,
default=None,
required=True,
help="Number of trees to grow before stopping.")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -1,171 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Demonstrates a regression on Boston housing data.
This example demonstrates how to run experiments with TF Boosted Trees on
a regression dataset. We split all the data into 20% test and 80% train,
and use l2 loss and l2 regularization.
Example Usage:
python tensorflow/contrib/boosted_trees/examples/boston.py \
--batch_size=404 --output_dir="/tmp/boston" --depth=4 --learning_rate=0.1 \
--num_eval_steps=1 --num_trees=500 --l2=0.001 \
--vmodule=training_ops=1
When training is done, mean squared error on eval data is reported.
Point tensorboard to the directory for the run to see how the training
progresses:
tensorboard --logdir=/tmp/boston
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import tensorflow as tf
from tensorflow.contrib.boosted_trees.estimator_batch import custom_export_strategy
from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeRegressor
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.learn import learn_runner
from tensorflow.python.util import compat
_BOSTON_NUM_FEATURES = 13
# Main config - creates a TF Boosted Trees Estimator based on flags.
def _get_tfbt(output_dir, feature_cols):
"""Configures TF Boosted Trees estimator based on flags."""
learner_config = learner_pb2.LearnerConfig()
learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate
learner_config.regularization.l1 = 0.0
learner_config.regularization.l2 = FLAGS.l2
learner_config.constraints.max_tree_depth = FLAGS.depth
run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300)
# Create a TF Boosted trees regression estimator.
estimator = GradientBoostedDecisionTreeRegressor(
learner_config=learner_config,
# This should be the number of examples. For large datasets it can be
# larger than the batch_size.
examples_per_layer=FLAGS.batch_size,
feature_columns=feature_cols,
label_dimension=1,
model_dir=output_dir,
num_trees=FLAGS.num_trees,
center_bias=False,
config=run_config)
return estimator
def _convert_fn(dtec, sorted_feature_names, num_dense, num_sparse_float,
num_sparse_int, export_dir, unused_eval_result):
universal_format = custom_export_strategy.convert_to_universal_format(
dtec, sorted_feature_names, num_dense, num_sparse_float, num_sparse_int)
with tf.gfile.GFile(os.path.join(
compat.as_bytes(export_dir), compat.as_bytes("tree_proto")), "w") as f:
f.write(str(universal_format))
def _make_experiment_fn(output_dir):
"""Creates experiment for gradient boosted decision trees."""
(x_train, y_train), (x_test,
y_test) = tf.keras.datasets.boston_housing.load_data()
train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
x={"x": x_train},
y=y_train,
batch_size=FLAGS.batch_size,
num_epochs=None,
shuffle=True)
eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False)
feature_columns = [
feature_column.real_valued_column("x", dimension=_BOSTON_NUM_FEATURES)
]
feature_spec = tf.contrib.layers.create_feature_spec_for_parsing(
feature_columns)
serving_input_fn = tf.contrib.learn.utils.build_parsing_serving_input_fn(
feature_spec)
# An export strategy that outputs the feature importance and also exports
# the internal tree representation in another format.
export_strategy = custom_export_strategy.make_custom_export_strategy(
"exports",
convert_fn=_convert_fn,
feature_columns=feature_columns,
export_input_fn=serving_input_fn)
return tf.contrib.learn.Experiment(
estimator=_get_tfbt(output_dir, feature_columns),
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
train_steps=None,
eval_steps=FLAGS.num_eval_steps,
eval_metrics=None,
export_strategies=[export_strategy])
def main(unused_argv):
learn_runner.run(
experiment_fn=_make_experiment_fn,
output_dir=FLAGS.output_dir,
schedule="train_and_evaluate")
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
parser = argparse.ArgumentParser()
# Define the list of flags that users can change.
parser.add_argument(
"--batch_size",
type=int,
default=1000,
help="The batch size for reading data.")
parser.add_argument(
"--output_dir",
type=str,
required=True,
help="Choose the dir for the output.")
parser.add_argument(
"--num_eval_steps",
type=int,
default=1,
help="The number of steps to run evaluation for.")
# Flags for gradient boosted trees config.
parser.add_argument(
"--depth", type=int, default=4, help="Maximum depth of weak learners.")
parser.add_argument(
"--l2", type=float, default=1.0, help="l2 regularization per batch.")
parser.add_argument(
"--learning_rate",
type=float,
default=0.1,
help="Learning rate (shrinkage weight) with which each new tree is added."
)
parser.add_argument(
"--num_trees",
type=int,
default=None,
required=True,
help="Number of trees to grow before stopping.")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -1,165 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Regression on Boston housing data using DNNBoostedTreeCombinedRegressor.
Example Usage:
python tensorflow/contrib/boosted_trees/examples/boston_combined.py \
--batch_size=404 --output_dir="/tmp/boston" \
--dnn_hidden_units="8,4" --dnn_steps_to_train=1000 \
--tree_depth=4 --tree_learning_rate=0.1 \
--num_trees=100 --tree_l2=0.001 --num_eval_steps=1 \
--vmodule=training_ops=1
When training is done, mean squared error on eval data is reported.
Point tensorboard to the directory for the run to see how the training
progresses:
tensorboard --logdir=/tmp/boston
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tensorflow as tf
from tensorflow.contrib.boosted_trees.estimator_batch.dnn_tree_combined_estimator import DNNBoostedTreeCombinedRegressor
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
_BOSTON_NUM_FEATURES = 13
def _get_estimator(output_dir, feature_cols):
"""Configures DNNBoostedTreeCombinedRegressor based on flags."""
learner_config = learner_pb2.LearnerConfig()
learner_config.learning_rate_tuner.fixed.learning_rate = (
FLAGS.tree_learning_rate)
learner_config.regularization.l1 = 0.0
learner_config.regularization.l2 = FLAGS.tree_l2
learner_config.constraints.max_tree_depth = FLAGS.tree_depth
run_config = tf.contrib.learn.RunConfig(save_summary_steps=1)
# Create a DNNBoostedTreeCombinedRegressor estimator.
estimator = DNNBoostedTreeCombinedRegressor(
dnn_hidden_units=[int(x) for x in FLAGS.dnn_hidden_units.split(",")],
dnn_feature_columns=feature_cols,
tree_learner_config=learner_config,
num_trees=FLAGS.num_trees,
# This should be the number of examples. For large datasets it can be
# larger than the batch_size.
tree_examples_per_layer=FLAGS.batch_size,
model_dir=output_dir,
config=run_config,
dnn_input_layer_to_tree=True,
dnn_steps_to_train=FLAGS.dnn_steps_to_train)
return estimator
def _make_experiment_fn(output_dir):
"""Creates experiment for DNNBoostedTreeCombinedRegressor."""
(x_train, y_train), (x_test,
y_test) = tf.keras.datasets.boston_housing.load_data()
train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
x={"x": x_train},
y=y_train,
batch_size=FLAGS.batch_size,
num_epochs=None,
shuffle=True)
eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False)
feature_columns = [
feature_column.real_valued_column("x", dimension=_BOSTON_NUM_FEATURES)
]
feature_spec = tf.contrib.layers.create_feature_spec_for_parsing(
feature_columns)
serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec)
export_strategies = [
saved_model_export_utils.make_export_strategy(serving_input_fn)]
return tf.contrib.learn.Experiment(
estimator=_get_estimator(output_dir, feature_columns),
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
train_steps=None,
eval_steps=FLAGS.num_eval_steps,
eval_metrics=None,
export_strategies=export_strategies)
def main(unused_argv):
learn_runner.run(
experiment_fn=_make_experiment_fn,
output_dir=FLAGS.output_dir,
schedule="train_and_evaluate")
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
parser = argparse.ArgumentParser()
# Define the list of flags that users can change.
parser.add_argument(
"--batch_size",
type=int,
default=1000,
help="The batch size for reading data.")
parser.add_argument(
"--output_dir",
type=str,
required=True,
help="Choose the dir for the output.")
parser.add_argument(
"--num_eval_steps",
type=int,
default=1,
help="The number of steps to run evaluation for.")
# Flags for configuring DNNBoostedTreeCombinedRegressor.
parser.add_argument(
"--dnn_hidden_units",
type=str,
default="8,4",
help="Hidden layers for DNN.")
parser.add_argument(
"--dnn_steps_to_train",
type=int,
default=1000,
help="Number of steps to train DNN.")
parser.add_argument(
"--tree_depth", type=int, default=4, help="Maximum depth of trees.")
parser.add_argument(
"--tree_l2", type=float, default=1.0, help="l2 regularization per batch.")
parser.add_argument(
"--tree_learning_rate",
type=float,
default=0.1,
help=("Learning rate (shrinkage weight) with which each "
"new tree is added."))
parser.add_argument(
"--num_trees",
type=int,
default=None,
required=True,
help="Number of trees to grow before stopping.")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -1,171 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Demonstrates multiclass MNIST TF Boosted trees example.
This example demonstrates how to run experiments with TF Boosted Trees on
the MNIST dataset. We use layer-by-layer boosting with the diagonal Hessian
strategy for multiclass handling, and cross-entropy loss.
Example Usage:
python tensorflow/contrib/boosted_trees/examples/mnist.py \
--output_dir="/tmp/mnist" --depth=4 --learning_rate=0.3 --batch_size=60000 \
--examples_per_layer=60000 --eval_batch_size=10000 --num_eval_steps=1 \
--num_trees=10 --l2=1 --vmodule=training_ops=1
When training is done, accuracy on eval data is reported. Point tensorboard
to the directory for the run to see how the training progresses:
tensorboard --logdir=/tmp/mnist
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import numpy as np
import tensorflow as tf
from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.learn import learn_runner
def get_input_fn(dataset_split,
batch_size,
capacity=10000,
min_after_dequeue=3000):
"""Input function over MNIST data."""
def _input_fn():
"""Prepare features and labels."""
images_batch, labels_batch = tf.train.shuffle_batch(
tensors=[dataset_split.images,
dataset_split.labels.astype(np.int32)],
batch_size=batch_size,
capacity=capacity,
min_after_dequeue=min_after_dequeue,
enqueue_many=True,
num_threads=4)
features_map = {"images": images_batch}
return features_map, labels_batch
return _input_fn
# Main config - creates a TF Boosted Trees Estimator based on flags.
def _get_tfbt(output_dir):
"""Configures TF Boosted Trees estimator based on flags."""
learner_config = learner_pb2.LearnerConfig()
num_classes = 10
learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate
learner_config.num_classes = num_classes
learner_config.regularization.l1 = 0.0
learner_config.regularization.l2 = FLAGS.l2 / FLAGS.examples_per_layer
learner_config.constraints.max_tree_depth = FLAGS.depth
growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER
learner_config.growing_mode = growing_mode
run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300)
learner_config.multi_class_strategy = (
learner_pb2.LearnerConfig.DIAGONAL_HESSIAN)
# Create a TF Boosted trees estimator that can take in custom loss.
estimator = GradientBoostedDecisionTreeClassifier(
learner_config=learner_config,
n_classes=num_classes,
examples_per_layer=FLAGS.examples_per_layer,
model_dir=output_dir,
num_trees=FLAGS.num_trees,
center_bias=False,
config=run_config)
return estimator
def _make_experiment_fn(output_dir):
"""Creates experiment for gradient boosted decision trees."""
data = tf.contrib.learn.datasets.mnist.load_mnist()
train_input_fn = get_input_fn(data.train, FLAGS.batch_size)
eval_input_fn = get_input_fn(data.validation, FLAGS.eval_batch_size)
return tf.contrib.learn.Experiment(
estimator=_get_tfbt(output_dir),
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
train_steps=None,
eval_steps=FLAGS.num_eval_steps,
eval_metrics=None)
def main(unused_argv):
learn_runner.run(
experiment_fn=_make_experiment_fn,
output_dir=FLAGS.output_dir,
schedule="train_and_evaluate")
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
parser = argparse.ArgumentParser()
# Define the list of flags that users can change.
parser.add_argument(
"--output_dir",
type=str,
required=True,
help="Choose the dir for the output.")
parser.add_argument(
"--batch_size",
type=int,
default=1000,
help="The batch size for reading data.")
parser.add_argument(
"--eval_batch_size",
type=int,
default=1000,
help="Size of the batch for eval.")
parser.add_argument(
"--num_eval_steps",
type=int,
default=1,
help="The number of steps to run evaluation for.")
# Flags for gradient boosted trees config.
parser.add_argument(
"--depth", type=int, default=4, help="Maximum depth of weak learners.")
parser.add_argument(
"--l2", type=float, default=1.0, help="l2 regularization per batch.")
parser.add_argument(
"--learning_rate",
type=float,
default=0.1,
help="Learning rate (shrinkage weight) with which each new tree is added."
)
parser.add_argument(
"--examples_per_layer",
type=int,
default=1000,
help="Number of examples to accumulate stats for per layer.")
parser.add_argument(
"--num_trees",
type=int,
default=None,
required=True,
help="Number of trees to grow before stopping.")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -1,212 +0,0 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <string>
#include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
namespace boosted_trees {
using boosted_trees::models::DecisionTreeEnsembleResource;
// Creates a tree ensemble variable.
class CreateTreeEnsembleVariableOp : public OpKernel {
public:
explicit CreateTreeEnsembleVariableOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Get the stamp token.
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
// Get the tree ensemble config.
const Tensor* tree_ensemble_config_t;
OP_REQUIRES_OK(context, context->input("tree_ensemble_config",
&tree_ensemble_config_t));
auto* result = new DecisionTreeEnsembleResource();
if (!result->InitFromSerialized(tree_ensemble_config_t->scalar<tstring>()(),
stamp_token)) {
result->Unref();
OP_REQUIRES(
context, false,
errors::InvalidArgument("Unable to parse tree ensemble config."));
}
// Only create the resource if one does not already exist. Report the status
// for all other errors.
auto status = CreateResource(context, HandleFromInput(context, 0), result);
if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
OP_REQUIRES(context, false, status);
}
}
};
// Op for retrieving a model stamp token without having to serialize.
class TreeEnsembleStampTokenOp : public OpKernel {
public:
explicit TreeEnsembleStampTokenOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
tf_shared_lock l(*ensemble_resource->get_mutex());
Tensor* output_stamp_token_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
&output_stamp_token_t));
output_stamp_token_t->scalar<int64>()() = ensemble_resource->stamp();
}
};
// Op for serializing a model.
class TreeEnsembleSerializeOp : public OpKernel {
public:
explicit TreeEnsembleSerializeOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
tf_shared_lock l(*ensemble_resource->get_mutex());
Tensor* output_stamp_token_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
&output_stamp_token_t));
output_stamp_token_t->scalar<int64>()() = ensemble_resource->stamp();
Tensor* output_config_t = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(1, TensorShape(), &output_config_t));
output_config_t->scalar<tstring>()() =
ensemble_resource->SerializeAsString();
}
};
// Op for deserializing a tree ensemble variable from a checkpoint.
class TreeEnsembleDeserializeOp : public OpKernel {
public:
explicit TreeEnsembleDeserializeOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
mutex_lock l(*ensemble_resource->get_mutex());
// Get the stamp token.
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
// Get the tree ensemble config.
const Tensor* tree_ensemble_config_t;
OP_REQUIRES_OK(context, context->input("tree_ensemble_config",
&tree_ensemble_config_t));
// Deallocate all the previous objects on the resource.
ensemble_resource->Reset();
OP_REQUIRES(
context,
ensemble_resource->InitFromSerialized(
tree_ensemble_config_t->scalar<tstring>()(), stamp_token),
errors::InvalidArgument("Unable to parse tree ensemble config."));
}
};
class TreeEnsembleUsedHandlersOp : public OpKernel {
public:
explicit TreeEnsembleUsedHandlersOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context,
context->GetAttr("num_all_handlers", &num_handlers_));
}
void Compute(OpKernelContext* context) override {
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
tf_shared_lock l(*ensemble_resource->get_mutex());
// Get the stamp token.
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
// Only the Chief should run this Op and it is guaranteed to be in
// a consistent state so the stamps must always match.
CHECK(ensemble_resource->is_stamp_valid(stamp_token));
Tensor* output_used_handlers_t = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output("used_handlers_mask", {num_handlers_},
&output_used_handlers_t));
auto output_used_handlers = output_used_handlers_t->vec<bool>();
Tensor* output_num_used_handlers_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("num_used_handlers", {},
&output_num_used_handlers_t));
int handler_idx = 0;
std::vector<int64> used_handlers = ensemble_resource->GetUsedHandlers();
output_num_used_handlers_t->scalar<int64>()() = used_handlers.size();
for (int64 i = 0; i < num_handlers_; ++i) {
if (handler_idx >= used_handlers.size() ||
used_handlers[handler_idx] > i) {
output_used_handlers(i) = false;
} else {
OP_REQUIRES(context, used_handlers[handler_idx] == i,
errors::InvalidArgument("Handler IDs should be sorted."));
++handler_idx;
output_used_handlers(i) = true;
}
}
}
private:
int64 num_handlers_;
};
REGISTER_RESOURCE_HANDLE_KERNEL(DecisionTreeEnsembleResource);
REGISTER_KERNEL_BUILDER(Name("TreeEnsembleIsInitializedOp").Device(DEVICE_CPU),
IsResourceInitialized<DecisionTreeEnsembleResource>);
REGISTER_KERNEL_BUILDER(Name("CreateTreeEnsembleVariable").Device(DEVICE_CPU),
CreateTreeEnsembleVariableOp);
REGISTER_KERNEL_BUILDER(Name("TreeEnsembleStampToken").Device(DEVICE_CPU),
TreeEnsembleStampTokenOp);
REGISTER_KERNEL_BUILDER(Name("TreeEnsembleSerialize").Device(DEVICE_CPU),
TreeEnsembleSerializeOp);
REGISTER_KERNEL_BUILDER(Name("TreeEnsembleDeserialize").Device(DEVICE_CPU),
TreeEnsembleDeserializeOp);
REGISTER_KERNEL_BUILDER(Name("TreeEnsembleUsedHandlers").Device(DEVICE_CPU),
TreeEnsembleUsedHandlersOp);
} // namespace boosted_trees
} // namespace tensorflow

View File

@ -1,468 +0,0 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <algorithm>
#include <string>
#include <vector>
#include "tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h"
#include "tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::boosted_trees::learner::AveragingConfig;
using tensorflow::boosted_trees::trees::DecisionTreeEnsembleConfig;
namespace tensorflow {
namespace boosted_trees {
using boosted_trees::learner::LearnerConfig;
using boosted_trees::learner::LearningRateConfig;
using boosted_trees::learner::LearningRateDropoutDrivenConfig;
using boosted_trees::models::DecisionTreeEnsembleResource;
using boosted_trees::models::MultipleAdditiveTrees;
using boosted_trees::utils::DropoutUtils;
using boosted_trees::utils::TensorUtils;
namespace {
const char* kLearnerConfigAttributeName = "learner_config";
const char* kSeedTensorName = "seed";
const char* kApplyDropoutAttributeName = "apply_dropout";
const char* kApplyAveragingAttributeName = "apply_averaging";
const char* kDropoutInfoOutputTensorName = "drop_out_tree_indices_weights";
const char* kPredictionsTensorName = "predictions";
const char* kLeafIndexTensorName = "leaf_index";
void CalculateTreesToInclude(
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
const std::vector<int32>& trees_to_drop, const int32 num_trees,
const bool only_finalized, const bool center_bias,
std::vector<int32>* trees_to_include) {
trees_to_include->reserve(num_trees - trees_to_drop.size());
int32 index = 0;
// This assumes that trees_to_drop is a sorted list of tree ids.
for (int32 tree = 0; tree < num_trees; ++tree) {
// Skip the tree if tree is in the list of trees_to_drop.
if (!trees_to_drop.empty() && index < trees_to_drop.size() &&
trees_to_drop[index] == tree) {
++index;
continue;
}
// Or skip if the tree is not finalized and only_finalized is set,
// with the exception of centering bias.
if (only_finalized && !(center_bias && tree == 0) &&
config.tree_metadata_size() > 0 &&
!config.tree_metadata(tree).is_finalized()) {
continue;
}
trees_to_include->push_back(tree);
}
}
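// Worked example for CalculateTreesToInclude (hypothetical inputs): with
// num_trees = 5, trees_to_drop = {1, 3} (sorted) and only_finalized = false,
// trees 0, 2 and 4 survive, so *trees_to_include = {0, 2, 4}. A single pass
// suffices because both the tree ids and trees_to_drop are visited in
// increasing order.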
} // namespace
class GradientTreesPredictionOp : public OpKernel {
public:
explicit GradientTreesPredictionOp(OpKernelConstruction* const context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("use_locking", &use_locking_));
OP_REQUIRES_OK(context, context->GetAttr("center_bias", &center_bias_));
OP_REQUIRES_OK(
context, context->GetAttr(kApplyDropoutAttributeName, &apply_dropout_));
LearnerConfig learner_config;
string learner_config_str;
OP_REQUIRES_OK(context, context->GetAttr(kLearnerConfigAttributeName,
&learner_config_str));
    OP_REQUIRES(
        context, ParseProtoUnlimited(&learner_config, learner_config_str),
        errors::InvalidArgument("Unable to parse learner config."));
    num_classes_ = learner_config.num_classes();
    OP_REQUIRES(context, num_classes_ >= 2,
                errors::InvalidArgument("Number of classes must be >=2"));
bool reduce_dim;
OP_REQUIRES_OK(context, context->GetAttr("reduce_dim", &reduce_dim));
prediction_vector_size_ = reduce_dim ? num_classes_ - 1 : num_classes_;
only_finalized_trees_ =
learner_config.growing_mode() == learner_config.WHOLE_TREE;
if (learner_config.has_learning_rate_tuner() &&
learner_config.learning_rate_tuner().tuner_case() ==
LearningRateConfig::kDropout) {
dropout_config_ = learner_config.learning_rate_tuner().dropout();
has_dropout_ = true;
} else {
has_dropout_ = false;
}
OP_REQUIRES_OK(context, context->GetAttr(kApplyAveragingAttributeName,
&apply_averaging_));
apply_averaging_ =
apply_averaging_ && learner_config.averaging_config().config_case() !=
AveragingConfig::CONFIG_NOT_SET;
if (apply_averaging_) {
averaging_config_ = learner_config.averaging_config();
// If there is averaging config, check that the values are correct.
switch (averaging_config_.config_case()) {
case AveragingConfig::kAverageLastNTreesFieldNumber: {
OP_REQUIRES(context, averaging_config_.average_last_n_trees() > 0,
errors::InvalidArgument(
"Average last n trees must be a positive number"));
break;
}
case AveragingConfig::kAverageLastPercentTreesFieldNumber: {
OP_REQUIRES(context,
averaging_config_.average_last_percent_trees() > 0 &&
averaging_config_.average_last_percent_trees() <= 1.0,
errors::InvalidArgument(
"Average last percent must be in (0,1] interval."));
break;
}
case AveragingConfig::CONFIG_NOT_SET: {
LOG(QFATAL) << "We should never get here.";
break;
}
}
}
}
void Compute(OpKernelContext* const context) override {
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
    // Look up the resource; the RefCountPtr releases the reference once
    // we're done using it.
    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
                                           &ensemble_resource));
if (use_locking_) {
tf_shared_lock l(*ensemble_resource->get_mutex());
DoCompute(context, ensemble_resource,
/*return_output_leaf_index=*/false);
} else {
DoCompute(context, ensemble_resource,
/*return_output_leaf_index=*/false);
}
}
protected:
  // return_output_leaf_index indicates whether to output the leaf index for
  // each prediction. This class always invokes DoCompute with the value
  // false; the subclass GradientTreesPredictionVerboseOp invokes it with
  // true.
virtual void DoCompute(
OpKernelContext* context,
const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
const bool return_output_leaf_index) {
// Read dense float features list;
OpInputList dense_float_features_list;
OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
context, &dense_float_features_list));
// Read sparse float features list;
OpInputList sparse_float_feature_indices_list;
OpInputList sparse_float_feature_values_list;
OpInputList sparse_float_feature_shapes_list;
OP_REQUIRES_OK(context, TensorUtils::ReadSparseFloatFeatures(
context, &sparse_float_feature_indices_list,
&sparse_float_feature_values_list,
&sparse_float_feature_shapes_list));
// Read sparse int features list;
OpInputList sparse_int_feature_indices_list;
OpInputList sparse_int_feature_values_list;
OpInputList sparse_int_feature_shapes_list;
OP_REQUIRES_OK(context, TensorUtils::ReadSparseIntFeatures(
context, &sparse_int_feature_indices_list,
&sparse_int_feature_values_list,
&sparse_int_feature_shapes_list));
// Infer batch size.
const int64 batch_size = TensorUtils::InferBatchSize(
dense_float_features_list, sparse_float_feature_shapes_list,
sparse_int_feature_shapes_list);
// Read batch features.
boosted_trees::utils::BatchFeatures batch_features(batch_size);
OP_REQUIRES_OK(
context,
batch_features.Initialize(
TensorUtils::OpInputListToTensorVec(dense_float_features_list),
TensorUtils::OpInputListToTensorVec(
sparse_float_feature_indices_list),
TensorUtils::OpInputListToTensorVec(
sparse_float_feature_values_list),
TensorUtils::OpInputListToTensorVec(
sparse_float_feature_shapes_list),
TensorUtils::OpInputListToTensorVec(
sparse_int_feature_indices_list),
TensorUtils::OpInputListToTensorVec(sparse_int_feature_values_list),
TensorUtils::OpInputListToTensorVec(
sparse_int_feature_shapes_list)));
std::vector<int32> dropped_trees;
std::vector<float> original_weights;
// Do dropout if needed.
if (apply_dropout_ && has_dropout_) {
// Read in seed and cast to uint64.
const Tensor* seed_t;
OP_REQUIRES_OK(context, context->input(kSeedTensorName, &seed_t));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_t->shape()),
errors::InvalidArgument("Seed must be a scalar."));
const uint64 seed = seed_t->scalar<int64>()();
std::unordered_set<int32> trees_not_to_drop;
if (center_bias_) {
trees_not_to_drop.insert(0);
}
if (ensemble_resource->decision_tree_ensemble().has_growing_metadata()) {
// We are in batch mode, the last tree is the tree that is being built,
// we can't drop it during dropout.
trees_not_to_drop.insert(ensemble_resource->num_trees() - 1);
}
const std::vector<float> weights = ensemble_resource->GetTreeWeights();
OP_REQUIRES_OK(context, DropoutUtils::DropOutTrees(
seed, dropout_config_, trees_not_to_drop,
weights, &dropped_trees, &original_weights));
}
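    // Example of the dropout bookkeeping above (illustrative values): if
    // DropOutTrees selects trees {2, 5}, then dropped_trees = {2, 5} and
    // original_weights holds their pre-dropout weights, which are reported
    // back through the dropout info output below. Trees in trees_not_to_drop
    // (the bias tree when center_bias_ is set, and the tree currently being
    // grown) are never dropped.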
// Prepare the list of trees to include in the prediction.
std::vector<int32> trees_to_include;
CalculateTreesToInclude(
ensemble_resource->decision_tree_ensemble(), dropped_trees,
ensemble_resource->decision_tree_ensemble().trees_size(),
only_finalized_trees_, center_bias_, &trees_to_include);
// Allocate output predictions matrix.
Tensor* output_predictions_t = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(kPredictionsTensorName,
{batch_size, prediction_vector_size_},
&output_predictions_t));
auto output_predictions = output_predictions_t->matrix<float>();
// Allocate output leaf index matrix.
Tensor* output_leaf_index_t = nullptr;
if (return_output_leaf_index) {
OP_REQUIRES_OK(context, context->allocate_output(
kLeafIndexTensorName,
{batch_size, ensemble_resource->num_trees()},
&output_leaf_index_t));
}
// Run predictor.
thread::ThreadPool* const worker_threads =
context->device()->tensorflow_cpu_worker_threads()->workers;
if (apply_averaging_) {
DecisionTreeEnsembleConfig adjusted =
ensemble_resource->decision_tree_ensemble();
const int start_averaging = std::max(
0.0,
averaging_config_.config_case() ==
AveragingConfig::kAverageLastNTreesFieldNumber
? adjusted.trees_size() - averaging_config_.average_last_n_trees()
: adjusted.trees_size() *
(1.0 - averaging_config_.average_last_percent_trees()));
const int num_ensembles = adjusted.trees_size() - start_averaging;
for (int i = start_averaging; i < adjusted.trees_size(); ++i) {
float weight = adjusted.tree_weights(i);
adjusted.mutable_tree_weights()->Set(
i, weight * (num_ensembles - i + start_averaging) / num_ensembles);
}
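      // Worked example for the re-weighting loop above (hypothetical
      // values): with 4 trees of weight 1.0 and average_last_n_trees = 2,
      // start_averaging = 2 and num_ensembles = 2, so tree 2 keeps weight
      // 1.0 * (2 - 2 + 2) / 2 = 1.0 while tree 3 is scaled down to
      // 1.0 * (2 - 3 + 2) / 2 = 0.5; earlier trees are left untouched.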
MultipleAdditiveTrees::Predict(adjusted, trees_to_include, batch_features,
worker_threads, output_predictions,
output_leaf_index_t);
} else {
MultipleAdditiveTrees::Predict(
ensemble_resource->decision_tree_ensemble(), trees_to_include,
batch_features, worker_threads, output_predictions,
output_leaf_index_t);
}
// Output dropped trees and original weights.
Tensor* output_dropout_info_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
kDropoutInfoOutputTensorName,
{2, static_cast<int64>(dropped_trees.size())},
&output_dropout_info_t));
auto output_dropout_info = output_dropout_info_t->matrix<float>();
for (int32 i = 0; i < dropped_trees.size(); ++i) {
output_dropout_info(0, i) = dropped_trees[i];
output_dropout_info(1, i) = original_weights[i];
}
}
private:
LearningRateDropoutDrivenConfig dropout_config_;
AveragingConfig averaging_config_;
bool only_finalized_trees_;
int num_classes_;
// What is the size of the output vector for predictions?
int prediction_vector_size_;
bool apply_dropout_;
bool center_bias_;
bool apply_averaging_;
bool use_locking_;
bool has_dropout_;
};
REGISTER_KERNEL_BUILDER(Name("GradientTreesPrediction").Device(DEVICE_CPU),
GradientTreesPredictionOp);
// GradientTreesPredictionVerboseOp is derived from GradientTreesPredictionOp
// and has an additional rank-2 tensor output containing, for each instance,
// the id of the leaf it ended up in within each tree.
class GradientTreesPredictionVerboseOp : public GradientTreesPredictionOp {
public:
explicit GradientTreesPredictionVerboseOp(OpKernelConstruction* const context)
: GradientTreesPredictionOp(context) {}
protected:
void DoCompute(
OpKernelContext* context,
const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
bool return_output_leaf_index) override {
GradientTreesPredictionOp::DoCompute(context, ensemble_resource,
/*return_output_leaf_index=*/true);
}
};
REGISTER_KERNEL_BUILDER(
Name("GradientTreesPredictionVerbose").Device(DEVICE_CPU),
GradientTreesPredictionVerboseOp);
class GradientTreesPartitionExamplesOp : public OpKernel {
public:
explicit GradientTreesPartitionExamplesOp(OpKernelConstruction* const context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("use_locking", &use_locking_));
}
void Compute(OpKernelContext* const context) override {
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
    // Look up the resource; the RefCountPtr releases the reference once
    // we're done using it.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
if (use_locking_) {
tf_shared_lock l(*ensemble_resource->get_mutex());
DoCompute(context, ensemble_resource);
} else {
DoCompute(context, ensemble_resource);
}
}
private:
void DoCompute(
OpKernelContext* context,
const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
// The last non-finalized tree in the ensemble is by convention the
// one to partition on. If no such tree exists, a nodeless tree is
// created.
boosted_trees::trees::DecisionTreeConfig empty_tree_config;
const boosted_trees::trees::DecisionTreeConfig& tree_config =
(resource->num_trees() <= 0 ||
resource->LastTreeMetadata()->is_finalized())
? empty_tree_config
: *resource->LastTree();
// Read dense float features list;
OpInputList dense_float_features_list;
OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
context, &dense_float_features_list));
// Read sparse float features list;
OpInputList sparse_float_feature_indices_list;
OpInputList sparse_float_feature_values_list;
OpInputList sparse_float_feature_shapes_list;
OP_REQUIRES_OK(context, TensorUtils::ReadSparseFloatFeatures(
context, &sparse_float_feature_indices_list,
&sparse_float_feature_values_list,
&sparse_float_feature_shapes_list));
// Read sparse int features list;
OpInputList sparse_int_feature_indices_list;
OpInputList sparse_int_feature_values_list;
OpInputList sparse_int_feature_shapes_list;
OP_REQUIRES_OK(context, TensorUtils::ReadSparseIntFeatures(
context, &sparse_int_feature_indices_list,
&sparse_int_feature_values_list,
&sparse_int_feature_shapes_list));
// Infer batch size.
const int64 batch_size = TensorUtils::InferBatchSize(
dense_float_features_list, sparse_float_feature_shapes_list,
sparse_int_feature_shapes_list);
// Read batch features.
boosted_trees::utils::BatchFeatures batch_features(batch_size);
OP_REQUIRES_OK(
context,
batch_features.Initialize(
TensorUtils::OpInputListToTensorVec(dense_float_features_list),
TensorUtils::OpInputListToTensorVec(
sparse_float_feature_indices_list),
TensorUtils::OpInputListToTensorVec(
sparse_float_feature_values_list),
TensorUtils::OpInputListToTensorVec(
sparse_float_feature_shapes_list),
TensorUtils::OpInputListToTensorVec(
sparse_int_feature_indices_list),
TensorUtils::OpInputListToTensorVec(sparse_int_feature_values_list),
TensorUtils::OpInputListToTensorVec(
sparse_int_feature_shapes_list)));
// Allocate output partitions vector.
Tensor* partition_ids_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, {batch_size}, &partition_ids_t));
thread::ThreadPool* const worker_threads =
context->device()->tensorflow_cpu_worker_threads()->workers;
learner::ExamplePartitioner::PartitionExamples(
tree_config, batch_features, worker_threads->NumThreads(),
worker_threads, partition_ids_t->vec<int32>().data());
}
private:
bool use_locking_;
};
REGISTER_KERNEL_BUILDER(
Name("GradientTreesPartitionExamples").Device(DEVICE_CPU),
GradientTreesPartitionExamplesOp);
} // namespace boosted_trees
} // namespace tensorflow

View File

@ -1,985 +0,0 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <algorithm>
#include <iterator>
#include <string>
#include <vector>
#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
#include "tensorflow/contrib/boosted_trees/proto/quantiles.pb.h"
#include "tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
using ::boosted_trees::QuantileConfig;
using boosted_trees::QuantileStreamResource;
using boosted_trees::utils::TensorUtils;
namespace {
const char* const kExampleWeightsName = "example_weights";
const char* const kMaxElementsName = "max_elements";
const char* const kNextStampTokenName = "next_stamp_token";
const char* const kStampTokenName = "stamp_token";
const char* const kAreBucketsReadyName = "are_buckets_ready";
const char* const kGenerateQuantiles = "generate_quantiles";
// Names for sparse arguments.
const char* const kNumSparseFeaturesName = "num_sparse_features";
const char* const kSparseBucketsName = "sparse_buckets";
const char* const kSparseValuesName = "sparse_values";
const char* const kSparseIndicesName = "sparse_indices";
const char* const kSparseSummariesName = "sparse_summaries";
const char* const kSparseConfigName = "sparse_config";
const char* const kSparseOutputTensorName = "sparse_quantiles";
// Names for dense arguments.
const char* const kDenseBucketsName = "dense_buckets";
const char* const kDenseConfigName = "dense_config";
const char* const kDenseOutputTensorName = "dense_quantiles";
const char* const kDenseSummariesName = "dense_summaries";
const char* const kDenseValuesName = "dense_values";
const char* const kNumDenseFeaturesName = "num_dense_features";
const char* const kResourceHandlesName = "quantile_accumulator_handles";
const char* const kNumQuantilesName = "num_quantiles";
const char* const kEpsilonName = "epsilon";
const char* const kBucketsName = "buckets";
const char* const kStreamStateName = "stream_state";
const char* const kSummariesName = "summaries";
using QuantileStream =
boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
using QuantileSummary =
boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
using QuantileSummaryEntry =
boosted_trees::quantiles::WeightedQuantilesSummary<float,
float>::SummaryEntry;
std::vector<float> GetBuckets(const int32 feature,
const OpInputList& buckets_list) {
const auto& buckets = buckets_list[feature].flat<float>();
std::vector<float> buckets_vector(buckets.data(),
buckets.data() + buckets.size());
return buckets_vector;
}
int32 GetFeatureDimension(const int32 feature_index, const int64 instance,
const OpInputList* const indices_list) {
if (indices_list != nullptr) {
// Sparse multidimensional.
return (*indices_list)[feature_index].matrix<int64>()(instance, 1);
}
// No indices, assume one-dimensional tensor.
return 0;
}
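// Example: for a sparse feature whose indices matrix has row `instance` equal
// to [example_id, dimension] = [7, 2], GetFeatureDimension returns 2; dense
// (one-dimensional) features always report dimension 0.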
// Allows quantization for each of multiple dimensions of a sparse feature.
void QuantizeFeatures(
const string& output_name, const OpInputList& values_list,
const OpInputList& buckets_list,
    const OpInputList* const
        indices_list /* Optional; provided for sparse features. */,
OpKernelContext* const context) {
if (values_list.size() == 0) {
return;
}
OpOutputList output_list;
OP_REQUIRES_OK(context, context->output_list(output_name, &output_list));
for (int32 feature_index = 0; feature_index < values_list.size();
++feature_index) {
const Tensor& values_tensor = values_list[feature_index];
const int64 num_values = values_tensor.dim_size(0);
Tensor* output_t = nullptr;
    // Each output row holds the bucket id and the feature dimension for that
    // value.
OP_REQUIRES_OK(
context, output_list.allocate(feature_index,
TensorShape({num_values, 2}), &output_t));
auto output = output_t->matrix<int32>();
const std::vector<float>& buckets_vector =
GetBuckets(feature_index, buckets_list);
auto flat_values = values_tensor.flat<float>();
for (int64 instance = 0; instance < num_values; ++instance) {
const float value = flat_values(instance);
CHECK(!buckets_vector.empty())
<< "Got empty buckets for feature " << feature_index;
auto bucket_iter =
std::lower_bound(buckets_vector.begin(), buckets_vector.end(), value);
if (bucket_iter == buckets_vector.end()) {
--bucket_iter;
}
const int32 bucket =
static_cast<int32>(bucket_iter - buckets_vector.begin());
// Bucket id.
output(instance, 0) = bucket;
// Dimension.
output(instance, 1) =
GetFeatureDimension(feature_index, instance, indices_list);
}
}
}
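// Bucket lookup example for the std::lower_bound logic above (hypothetical
// boundaries): with buckets_vector = {0.0, 1.0, 2.0}, value 0.5 maps to
// bucket 1 and value 1.0 also maps to bucket 1, while values beyond the last
// boundary (e.g. 5.0) are clamped into the last bucket (index 2) by the
// --bucket_iter adjustment.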
// Validates attributes for the quantile ops.
Status ReadAndValidateAttributes(OpKernelConstruction* const context,
int* num_dense_features,
int* num_sparse_features) {
TF_RETURN_IF_ERROR(
context->GetAttr(kNumDenseFeaturesName, num_dense_features));
TF_RETURN_IF_ERROR(
context->GetAttr(kNumSparseFeaturesName, num_sparse_features));
if ((*num_dense_features) + (*num_sparse_features) == 0) {
return errors::InvalidArgument(
"Please provide at least sparse or dense features.");
}
return Status::OK();
}
void ParseConfig(OpKernelConstruction* const context, const string& name,
std::vector<QuantileConfig>* output) {
std::vector<string> serialized_config;
OP_REQUIRES_OK(context, context->GetAttr(name, &serialized_config));
output->reserve(serialized_config.size());
QuantileConfig tmp;
for (const auto& serialized_string : serialized_config) {
OP_REQUIRES(context, tmp.ParseFromString(serialized_string),
errors::InvalidArgument("Malformed QuantileConfig passed in."));
output->push_back(tmp);
}
}
// Generates quantiles on a finalized QuantileStream.
std::vector<float> GenerateBoundaries(const QuantileStream& stream,
int num_boundaries) {
std::vector<float> boundaries = stream.GenerateBoundaries(num_boundaries);
// Uniquify elements as we may get dupes.
auto end_it = std::unique(boundaries.begin(), boundaries.end());
boundaries.resize(std::distance(boundaries.begin(), end_it));
return boundaries;
}
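// Example: if the stream returns boundaries {0.1, 0.5, 0.5, 0.9} (duplicates
// can occur for heavily skewed data), the std::unique/resize pair above
// shrinks them to {0.1, 0.5, 0.9}.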
// Generates quantiles on a finalized QuantileStream.
std::vector<float> GenerateQuantiles(const QuantileStream& stream,
int num_quantiles) {
// Do not de-dup boundaries. Exactly num_quantiles+1 boundary values
// will be returned.
std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles);
CHECK_EQ(boundaries.size(), num_quantiles + 1);
return boundaries;
}
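// Example: with num_quantiles = 4 the stream returns 5 boundary values,
// conceptually the minimum, the three inner quartile cut points, and the
// maximum; duplicates are preserved so callers can rely on the count.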
// Copies quantiles to output list.
void CopyBoundaries(OpKernelContext* const context,
const std::vector<float>& boundaries, const int64 index,
OpOutputList* output_list) {
// Output to tensor.
Tensor* output_t = nullptr;
OP_REQUIRES_OK(
context, output_list->allocate(
index, {static_cast<int64>(boundaries.size())}, &output_t));
auto* quantiles_flat = output_t->flat<float>().data();
memcpy(quantiles_flat, boundaries.data(), sizeof(float) * boundaries.size());
}
void CopySummaryToProto(const QuantileSummary& summary,
::boosted_trees::QuantileSummaryState* summary_proto) {
summary_proto->mutable_entries()->Reserve(summary.Size());
for (const auto& entry : summary.GetEntryList()) {
auto* new_entry = summary_proto->add_entries();
new_entry->set_value(entry.value);
new_entry->set_weight(entry.weight);
new_entry->set_min_rank(entry.min_rank);
new_entry->set_max_rank(entry.max_rank);
}
}
} // namespace
// Accumulator for Quantile Summaries.
REGISTER_RESOURCE_HANDLE_KERNEL(QuantileStreamResource);
REGISTER_KERNEL_BUILDER(
Name("QuantileAccumulatorIsInitialized").Device(DEVICE_CPU),
IsResourceInitialized<QuantileStreamResource>);
class CreateQuantileAccumulatorOp : public OpKernel {
public:
explicit CreateQuantileAccumulatorOp(OpKernelConstruction* const context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr(kEpsilonName, &epsilon_));
OP_REQUIRES_OK(context,
context->GetAttr(kNumQuantilesName, &num_quantiles_));
OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_));
OP_REQUIRES_OK(context,
context->GetAttr(kGenerateQuantiles, &generate_quantiles_));
}
void Compute(OpKernelContext* context) override {
    // Only create one if one does not exist already. Report the status for
    // all other errors. If one already exists, the new resource is unreffed.
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
    // An epsilon value of zero could cause performance issues and is
    // therefore disallowed.
OP_REQUIRES(
context, epsilon_ > 0,
errors::InvalidArgument("An epsilon value of zero is not allowed."));
auto result = new QuantileStreamResource(epsilon_, num_quantiles_,
max_elements_, generate_quantiles_,
stamp_token_t->scalar<int64>()());
auto status = CreateResource(context, HandleFromInput(context, 0), result);
if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
OP_REQUIRES(context, false, status);
}
}
private:
float epsilon_;
int32 num_quantiles_;
// An upper bound on the number of entries that the summaries might have
// for a feature.
int64 max_elements_;
bool generate_quantiles_;
};
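// Note on epsilon (the usual guarantee for weighted quantile sketches, stated
// here as background rather than taken from this file): epsilon = 0.01 means
// a boundary returned for quantile q has true rank within q +/- 0.01 of the
// total weight, which is why an exact sketch (epsilon == 0) over unbounded
// input is rejected above.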
REGISTER_KERNEL_BUILDER(Name("CreateQuantileAccumulator").Device(DEVICE_CPU),
CreateQuantileAccumulatorOp);
// Adds a summary to the quantile summary stream.
class QuantileAccumulatorAddSummariesOp : public OpKernel {
public:
explicit QuantileAccumulatorAddSummariesOp(
OpKernelConstruction* const context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
OpInputList resource_handle_list;
OP_REQUIRES_OK(context, context->input_list(kResourceHandlesName,
&resource_handle_list));
OpInputList summary_list;
OP_REQUIRES_OK(context, context->input_list(kSummariesName, &summary_list));
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
thread::ThreadPool* const worker_threads =
context->device()->tensorflow_cpu_worker_threads()->workers;
boosted_trees::utils::ParallelFor(
resource_handle_list.size(), worker_threads->NumThreads(),
worker_threads,
[&context, &resource_handle_list, &summary_list, stamp_token](
int64 start, int64 end) {
for (int resource_handle_idx = start; resource_handle_idx < end;
++resource_handle_idx) {
const ResourceHandle& handle =
resource_handle_list[resource_handle_idx]
.flat<ResourceHandle>()(0);
core::RefCountPtr<QuantileStreamResource> streams_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context,
LookupResource(context, handle, &streams_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*streams_resource->mutex());
// If the stamp is invalid we drop the update.
if (!streams_resource->is_stamp_valid(stamp_token)) {
VLOG(1)
<< "Invalid stamp token in QuantileAccumulatorAddSummariesOp."
<< " Passed stamp token: " << stamp_token << " "
<< "Current token: " << streams_resource->stamp();
return;
}
protobuf::Arena arena;
::boosted_trees::QuantileSummaryState* summary_proto =
protobuf::Arena::CreateMessage<
::boosted_trees::QuantileSummaryState>(&arena);
OP_REQUIRES(
context,
ParseProtoUnlimited(
summary_proto,
summary_list[resource_handle_idx].scalar<tstring>()()),
errors::InvalidArgument("Unable to parse quantile summary."));
std::vector<QuantileSummaryEntry> entries;
entries.reserve(summary_proto->entries_size());
for (const auto& entry : summary_proto->entries()) {
entries.emplace_back(entry.value(), entry.weight(),
entry.min_rank(), entry.max_rank());
}
// Add the summary to the quantile stream.
streams_resource->stream(stamp_token)->PushSummary(entries);
}
});
}
};
REGISTER_KERNEL_BUILDER(
Name("QuantileAccumulatorAddSummaries").Device(DEVICE_CPU),
QuantileAccumulatorAddSummariesOp);
// Generates summaries for a given set of float values and the given config.
class MakeQuantileSummariesOp : public OpKernel {
public:
explicit MakeQuantileSummariesOp(OpKernelConstruction* const context)
: OpKernel(context) {
OP_REQUIRES_OK(context,
ReadAndValidateAttributes(context, &num_dense_features_,
&num_sparse_features_));
OP_REQUIRES_OK(context, context->GetAttr(kEpsilonName, &epsilon_));
}
void Compute(OpKernelContext* const context) override {
// Read dense float features list;
OpInputList dense_float_features_list;
OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
context, &dense_float_features_list));
// Read sparse float features list;
OpInputList sparse_float_feature_indices_list;
OpInputList sparse_float_feature_values_list;
OpInputList sparse_float_feature_shapes_list;
OP_REQUIRES_OK(context, TensorUtils::ReadSparseFloatFeatures(
context, &sparse_float_feature_indices_list,
&sparse_float_feature_values_list,
&sparse_float_feature_shapes_list));
// Parse example weights and get batch size.
const Tensor* example_weights_t;
OP_REQUIRES_OK(context,
context->input(kExampleWeightsName, &example_weights_t));
auto example_weights = example_weights_t->flat<float>();
const int64 batch_size = example_weights.size();
OpOutputList sparse_summaries_output_list;
OP_REQUIRES_OK(context,
context->output_list(kSparseSummariesName,
&sparse_summaries_output_list));
OpOutputList dense_summaries_output_list;
OP_REQUIRES_OK(context, context->output_list(kDenseSummariesName,
&dense_summaries_output_list));
auto do_quantile_summary_gen = [&](const int64 begin, const int64 end) {
auto copy_over_summaries = [&](const QuantileStream& stream,
const int64 index,
OpOutputList* output_list) {
protobuf::Arena arena;
::boosted_trees::QuantileSummaryState* summary_proto =
protobuf::Arena::CreateMessage<
::boosted_trees::QuantileSummaryState>(&arena);
const auto& summary = stream.GetFinalSummary();
CopySummaryToProto(summary, summary_proto);
// Output to tensor.
Tensor* output_t = nullptr;
OP_REQUIRES_OK(context, output_list->allocate(index, {}, &output_t));
SerializeToTString(*summary_proto, &output_t->scalar<tstring>()());
};
      // These are blocks of ranges over both feature kinds, i.e. the range
      // [0, num_dense_features_ + num_sparse_features_); dense features come
      // first, then sparse.
for (int64 i = begin; i < end; ++i) {
if (i < num_dense_features_) {
const int64 dense_index = i;
const auto dense_values =
dense_float_features_list[dense_index].flat<float>();
QuantileStream stream(epsilon_, batch_size + 1);
// Run quantile summary generation.
for (int64 j = 0; j < batch_size; ++j) {
stream.PushEntry(dense_values(j), example_weights(j));
}
stream.Finalize();
// Copy summaries to output.
copy_over_summaries(stream, dense_index,
&dense_summaries_output_list);
} else {
const int64 sparse_index = i - num_dense_features_;
const auto sparse_values =
sparse_float_feature_values_list[sparse_index].flat<float>();
const auto sparse_indices =
sparse_float_feature_indices_list[sparse_index].matrix<int64>();
const auto dense_shape =
sparse_float_feature_shapes_list[sparse_index].flat<int64>();
OP_REQUIRES(context, batch_size == dense_shape(0),
errors::InvalidArgument(
"Sparse column shape doesn't match the batch size."));
QuantileStream stream(epsilon_, batch_size + 1);
// Run quantile summary generation.
const int64 num_sparse_rows =
sparse_float_feature_indices_list[sparse_index].dim_size(0);
for (int64 j = 0; j < num_sparse_rows; ++j) {
const int64 example_id = sparse_indices(j, 0);
stream.PushEntry(sparse_values(j), example_weights(example_id));
}
stream.Finalize();
// Copy summaries to output.
copy_over_summaries(stream, sparse_index,
&sparse_summaries_output_list);
}
}
};
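    // Shard splits the [0, num_features) iteration range across the CPU
    // worker pool; kCostPerUnit below is a rough, hand-tuned cost model
    // (proportional to the batch size) that tells the sharder how expensive
    // each feature is, so small workloads are not fanned out needlessly.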
const int64 kCostPerUnit = 500 * batch_size;
const int64 num_features = num_sparse_features_ + num_dense_features_;
const DeviceBase::CpuWorkerThreads& worker_threads =
*context->device()->tensorflow_cpu_worker_threads();
Shard(worker_threads.num_threads, worker_threads.workers, num_features,
kCostPerUnit, do_quantile_summary_gen);
}
private:
int num_dense_features_;
int num_sparse_features_;
float epsilon_;
};
REGISTER_KERNEL_BUILDER(Name("MakeQuantileSummaries").Device(DEVICE_CPU),
MakeQuantileSummariesOp);
// Serializes the state of streams.
class QuantileAccumulatorSerializeOp : public OpKernel {
public:
explicit QuantileAccumulatorSerializeOp(OpKernelConstruction* const context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
core::RefCountPtr<QuantileStreamResource> streams_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&streams_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*streams_resource->mutex());
int64 stamp_token = streams_resource->stamp();
Tensor* stream_state_t;
OP_REQUIRES_OK(context,
context->allocate_output(kStreamStateName, TensorShape({}),
&stream_state_t));
bool are_buckets_ready = streams_resource->are_buckets_ready();
    // Serialize the stream's internal summaries and, if the buckets are
    // ready, the bucket boundaries for the current stamp.
const QuantileStream& stream = *streams_resource->stream(stamp_token);
const std::vector<float>& boundaries =
are_buckets_ready ? streams_resource->boundaries(stamp_token)
: std::vector<float>();
protobuf::Arena arena;
::boosted_trees::QuantileStreamState* stream_proto =
protobuf::Arena::CreateMessage<::boosted_trees::QuantileStreamState>(
&arena);
for (const auto& summary : stream.SerializeInternalSummaries()) {
CopySummaryToProto(summary, stream_proto->add_summaries());
}
SerializeToTString(*stream_proto, &stream_state_t->scalar<tstring>()());
Tensor* buckets_t = nullptr;
OP_REQUIRES_OK(
context,
context->allocate_output(
kBucketsName, {static_cast<int64>(boundaries.size())}, &buckets_t));
auto* quantiles_flat = buckets_t->flat<float>().data();
memcpy(quantiles_flat, boundaries.data(),
sizeof(float) * boundaries.size());
Tensor* stamp_token_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(kStampTokenName, TensorShape({}),
&stamp_token_t));
stamp_token_t->scalar<int64>()() = stamp_token;
Tensor* are_buckets_ready_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(kAreBucketsReadyName, {},
&are_buckets_ready_t));
are_buckets_ready_t->scalar<bool>()() = are_buckets_ready;
}
};
REGISTER_KERNEL_BUILDER(Name("QuantileAccumulatorSerialize").Device(DEVICE_CPU),
QuantileAccumulatorSerializeOp);
// Deserializes the state of streams.
class QuantileAccumulatorDeserializeOp : public OpKernel {
public:
explicit QuantileAccumulatorDeserializeOp(OpKernelConstruction* const context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
core::RefCountPtr<QuantileStreamResource> streams_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&streams_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*streams_resource->mutex());
int64 old_stamp_token = streams_resource->stamp();
const Tensor* stream_state_t;
OP_REQUIRES_OK(context, context->input(kStreamStateName, &stream_state_t));
const Tensor* buckets_t;
OP_REQUIRES_OK(context, context->input(kBucketsName, &buckets_t));
QuantileStream* stream = streams_resource->stream(old_stamp_token);
::boosted_trees::QuantileStreamState state_proto;
OP_REQUIRES(
context,
ParseProtoUnlimited(&state_proto, stream_state_t->scalar<tstring>()()),
errors::InvalidArgument("Unabnle to parse quantile stream state."));
std::vector<QuantileSummary> summaries;
summaries.reserve(state_proto.summaries_size());
std::vector<QuantileSummaryEntry> entries;
for (const auto& summary : state_proto.summaries()) {
entries.clear();
entries.reserve(summary.entries_size());
for (const auto& entry : summary.entries()) {
entries.emplace_back(entry.value(), entry.weight(), entry.min_rank(),
entry.max_rank());
}
summaries.emplace_back();
summaries[summaries.size() - 1].BuildFromSummaryEntries(entries);
}
stream->DeserializeInternalSummaries(summaries);
const auto& buckets = buckets_t->vec<float>();
std::vector<float> result;
result.reserve(buckets.size());
for (size_t i = 0; i < buckets.size(); ++i) {
result.push_back(buckets(i));
}
streams_resource->set_boundaries(old_stamp_token, result);
// Reset the stamp token.
const Tensor* stamp_token_t = nullptr;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
streams_resource->set_stamp(stamp_token);
const Tensor* are_buckets_ready_t = nullptr;
OP_REQUIRES_OK(context,
context->input(kAreBucketsReadyName, &are_buckets_ready_t));
streams_resource->set_buckets_ready(are_buckets_ready_t->scalar<bool>()());
}
};
REGISTER_KERNEL_BUILDER(
Name("QuantileAccumulatorDeserialize").Device(DEVICE_CPU),
QuantileAccumulatorDeserializeOp);
// Flushes the quantile summary stream resource.
class QuantileAccumulatorFlushOp : public OpKernel {
public:
explicit QuantileAccumulatorFlushOp(OpKernelConstruction* const context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
core::RefCountPtr<QuantileStreamResource> streams_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&streams_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*streams_resource->mutex());
const Tensor* next_stamp_token_t;
OP_REQUIRES_OK(context,
context->input(kNextStampTokenName, &next_stamp_token_t));
int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
CHECK(streams_resource->is_stamp_valid(stamp_token))
<< "Invalid stamp token in QuantileAccumulatorFlushOp. "
<< "Passed stamp token: " << stamp_token << " "
<< "Current token: " << streams_resource->stamp();
QuantileStream* stream = streams_resource->stream(stamp_token);
bool generate_quantiles = streams_resource->generate_quantiles();
stream->Finalize();
streams_resource->set_boundaries(
stamp_token,
generate_quantiles
? GenerateQuantiles(*stream, streams_resource->num_quantiles())
: GenerateBoundaries(*stream, streams_resource->num_quantiles()));
streams_resource->Reset(next_stamp_token);
}
};
REGISTER_KERNEL_BUILDER(Name("QuantileAccumulatorFlush").Device(DEVICE_CPU),
QuantileAccumulatorFlushOp);
// Flushes the quantile summary stream resource. This version computes the
// summary.
class QuantileAccumulatorFlushSummaryOp : public OpKernel {
public:
explicit QuantileAccumulatorFlushSummaryOp(
OpKernelConstruction* const context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
core::RefCountPtr<QuantileStreamResource> streams_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&streams_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*streams_resource->mutex());
const Tensor* next_stamp_token_t;
OP_REQUIRES_OK(context,
context->input(kNextStampTokenName, &next_stamp_token_t));
int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
CHECK(streams_resource->is_stamp_valid(stamp_token))
<< "Invalid stamp token in QuantileAccumulatorFlushSummaryOp. "
<< "Passed stamp token: " << stamp_token << " "
<< "Current token: " << streams_resource->stamp();
QuantileStream* stream = streams_resource->stream(stamp_token);
stream->Finalize();
protobuf::Arena arena;
::boosted_trees::QuantileSummaryState* summary_proto =
protobuf::Arena::CreateMessage<::boosted_trees::QuantileSummaryState>(
&arena);
const auto& summary = stream->GetFinalSummary();
CopySummaryToProto(summary, summary_proto);
// Output to tensor.
Tensor* output_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({}), &output_t));
SerializeToTString(*summary_proto, &output_t->scalar<tstring>()());
streams_resource->Reset(next_stamp_token);
}
};
REGISTER_KERNEL_BUILDER(
Name("QuantileAccumulatorFlushSummary").Device(DEVICE_CPU),
QuantileAccumulatorFlushSummaryOp);
// Get bucket boundaries from summaries.
class QuantileAccumulatorGetBucketsOp : public OpKernel {
public:
explicit QuantileAccumulatorGetBucketsOp(OpKernelConstruction* const context)
: OpKernel(context) {}
void Compute(OpKernelContext* const context) override {
OpInputList resource_handle_list;
OP_REQUIRES_OK(context, context->input_list(kResourceHandlesName,
&resource_handle_list));
OpOutputList are_buckets_ready_list;
OP_REQUIRES_OK(context, context->output_list(kAreBucketsReadyName,
&are_buckets_ready_list));
OpOutputList buckets_list;
OP_REQUIRES_OK(context, context->output_list(kBucketsName, &buckets_list));
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
thread::ThreadPool* const worker_threads =
context->device()->tensorflow_cpu_worker_threads()->workers;
boosted_trees::utils::ParallelFor(
resource_handle_list.size(), worker_threads->NumThreads(),
worker_threads,
[&context, &resource_handle_list, &are_buckets_ready_list,
&buckets_list, stamp_token](int64 start, int64 end) {
for (int resource_handle_idx = start; resource_handle_idx < end;
++resource_handle_idx) {
const ResourceHandle& handle =
resource_handle_list[resource_handle_idx]
.flat<ResourceHandle>()(0);
core::RefCountPtr<QuantileStreamResource> streams_resource;
OP_REQUIRES_OK(context,
LookupResource(context, handle, &streams_resource));
// Remove the reference at the end of this scope.
mutex_lock l(*streams_resource->mutex());
bool are_buckets_ready =
streams_resource->is_stamp_valid(stamp_token) &&
streams_resource->are_buckets_ready();
Tensor* are_buckets_ready_t = nullptr;
OP_REQUIRES_OK(context,
are_buckets_ready_list.allocate(
resource_handle_idx, {}, &are_buckets_ready_t));
are_buckets_ready_t->scalar<bool>()() = are_buckets_ready;
const std::vector<float>& boundaries =
are_buckets_ready ? streams_resource->boundaries(stamp_token)
: std::vector<float>();
Tensor* output_t = nullptr;
OP_REQUIRES_OK(context, buckets_list.allocate(
resource_handle_idx,
{static_cast<int64>(boundaries.size())},
&output_t));
auto* quantiles_flat = output_t->flat<float>().data();
memcpy(quantiles_flat, boundaries.data(),
sizeof(float) * boundaries.size());
}
});
}
};
REGISTER_KERNEL_BUILDER(
Name("QuantileAccumulatorGetBuckets").Device(DEVICE_CPU),
QuantileAccumulatorGetBucketsOp);
// Generates buckets for a given set of float values and the given config.
class QuantileBucketsOp : public OpKernel {
public:
explicit QuantileBucketsOp(OpKernelConstruction* const context)
: OpKernel(context) {
OP_REQUIRES_OK(context,
ReadAndValidateAttributes(context, &num_dense_features_,
&num_sparse_features_));
ParseConfig(context, kDenseConfigName, &dense_configs_);
OP_REQUIRES(context, dense_configs_.size() == num_dense_features_,
errors::InvalidArgument(
"Mismatch in number of dense quantile configs."));
ParseConfig(context, kSparseConfigName, &sparse_configs_);
OP_REQUIRES(context, sparse_configs_.size() == num_sparse_features_,
errors::InvalidArgument(
"Mismatch in number of sparse quantile configs."));
}
void Compute(OpKernelContext* const context) override {
// Read dense float features list;
OpInputList dense_float_features_list;
OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
context, &dense_float_features_list));
// Read sparse float features list;
OpInputList sparse_float_feature_indices_list;
OpInputList sparse_float_feature_values_list;
OpInputList sparse_float_feature_shapes_list;
OP_REQUIRES_OK(context, TensorUtils::ReadSparseFloatFeatures(
context, &sparse_float_feature_indices_list,
&sparse_float_feature_values_list,
&sparse_float_feature_shapes_list));
// Parse example weights and get batch size.
const Tensor* example_weights_t;
OP_REQUIRES_OK(context,
context->input(kExampleWeightsName, &example_weights_t));
auto example_weights = example_weights_t->flat<float>();
const int64 batch_size = example_weights.size();
OpOutputList sparse_buckets_output_list;
OP_REQUIRES_OK(context, context->output_list(kSparseBucketsName,
&sparse_buckets_output_list));
OpOutputList dense_buckets_output_list;
OP_REQUIRES_OK(context, context->output_list(kDenseBucketsName,
&dense_buckets_output_list));
auto do_quantile_bucket_gen = [&](const int64 begin, const int64 end) {
      // These are blocks of ranges over both feature kinds, i.e. the range
      // [0, sparse_configs_.size() + dense_configs_.size()); sparse features
      // come first, then dense.
for (int64 i = begin; i < end; ++i) {
if (i < sparse_configs_.size()) {
const int64 sparse_index = i;
const auto sparse_values =
sparse_float_feature_values_list[sparse_index].flat<float>();
const auto sparse_indices =
sparse_float_feature_indices_list[sparse_index].matrix<int64>();
QuantileStream stream(sparse_configs_[sparse_index].eps(),
batch_size);
// Run quantile summary generation.
const int64 num_sparse_rows =
sparse_float_feature_indices_list[sparse_index].dim_size(0);
for (int64 j = 0; j < num_sparse_rows; ++j) {
const int64 example_id = sparse_indices(j, 0);
stream.PushEntry(sparse_values(j), example_weights(example_id));
}
stream.Finalize();
// Create buckets.
const auto boundaries = GenerateBoundaries(
stream, sparse_configs_[sparse_index].num_quantiles());
CopyBoundaries(context, boundaries, sparse_index,
&sparse_buckets_output_list);
} else {
const int64 dense_index = i - sparse_configs_.size();
const auto dense_values =
dense_float_features_list[dense_index].flat<float>();
QuantileStream stream(dense_configs_[dense_index].eps(), batch_size);
// Run quantile summary generation.
for (int64 j = 0; j < batch_size; ++j) {
stream.PushEntry(dense_values(j), example_weights(j));
}
stream.Finalize();
// Create buckets.
const auto boundaries = GenerateBoundaries(
stream, dense_configs_[dense_index].num_quantiles());
CopyBoundaries(context, boundaries, dense_index,
&dense_buckets_output_list);
}
}
};
const int64 kCostPerUnit = 500 * batch_size;
const int64 num_features = sparse_configs_.size() + dense_configs_.size();
const DeviceBase::CpuWorkerThreads& worker_threads =
*context->device()->tensorflow_cpu_worker_threads();
Shard(worker_threads.num_threads, worker_threads.workers, num_features,
kCostPerUnit, do_quantile_bucket_gen);
}
private:
int num_dense_features_;
int num_sparse_features_;
std::vector<QuantileConfig> dense_configs_;
std::vector<QuantileConfig> sparse_configs_;
};
REGISTER_KERNEL_BUILDER(Name("QuantileBuckets").Device(DEVICE_CPU),
QuantileBucketsOp);
// Given the calculated quantiles thresholds and input data, this operation
// converts the input features into the buckets (categorical values), depending
// on which quantile they fall into.
class QuantilesOp : public OpKernel {
public:
explicit QuantilesOp(OpKernelConstruction* const context)
: OpKernel(context) {
int num_dense_features;
int num_sparse_features;
OP_REQUIRES_OK(context,
ReadAndValidateAttributes(context, &num_dense_features,
&num_sparse_features));
}
void Compute(OpKernelContext* const context) override {
// Dense features inputs
OpInputList dense_float_features_list;
OP_REQUIRES_OK(context, context->input_list(kDenseValuesName,
&dense_float_features_list));
OpInputList dense_buckets_list;
OP_REQUIRES_OK(context,
context->input_list(kDenseBucketsName, &dense_buckets_list));
if (dense_buckets_list.size() > 0) {
// Check the first tensor to make sure it is the right shape
OP_REQUIRES(
context,
tensorflow::TensorShapeUtils::IsVector(dense_buckets_list[0].shape()),
errors::InvalidArgument(
strings::Printf("Dense buckets should be flat vectors")));
}
// Sparse features inputs
OpInputList sparse_float_feature_values_list;
OP_REQUIRES_OK(context,
context->input_list(kSparseValuesName,
&sparse_float_feature_values_list));
OpInputList sparse_float_indices_list;
OP_REQUIRES_OK(context, context->input_list(kSparseIndicesName,
&sparse_float_indices_list));
OpInputList sparse_buckets_list;
OP_REQUIRES_OK(
context, context->input_list(kSparseBucketsName, &sparse_buckets_list));
if (sparse_buckets_list.size() > 0) {
OP_REQUIRES(
context,
tensorflow::TensorShapeUtils::IsVector(
sparse_buckets_list[0].shape()),
errors::InvalidArgument("Sparse buckets should be flat vectors"));
}
// Quantize the feature values
QuantizeFeatures(kDenseOutputTensorName, dense_float_features_list,
dense_buckets_list, nullptr, context);
QuantizeFeatures(kSparseOutputTensorName, sparse_float_feature_values_list,
sparse_buckets_list, &sparse_float_indices_list, context);
}
};
REGISTER_KERNEL_BUILDER(Name("Quantiles").Device(DEVICE_CPU), QuantilesOp);
template <typename T>
class BucketizeWithInputBoundariesOp : public OpKernel {
public:
explicit BucketizeWithInputBoundariesOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor& boundaries_tensor = context->input(1);
VLOG(1) << "boundaries has shape: "
<< boundaries_tensor.shape().DebugString();
auto boundaries = boundaries_tensor.flat<float>();
std::vector<T> boundaries_vector;
boundaries_vector.reserve(boundaries.size());
for (size_t i = 0; i < boundaries.size(); i++) {
boundaries_vector.push_back(boundaries(i));
VLOG(1) << "boundaries(" << i << ") : " << boundaries(i);
}
OP_REQUIRES(
context,
std::is_sorted(boundaries_vector.begin(), boundaries_vector.end()),
errors::InvalidArgument("Expected sorted boundaries"));
const Tensor& input_tensor = context->input(0);
VLOG(1) << "Inputs has shape: " << input_tensor.shape().DebugString()
<< " Dtype: " << tensorflow::DataTypeString(input_tensor.dtype());
auto input = input_tensor.flat<T>();
Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
auto output = output_tensor->template flat<int32>();
for (size_t i = 0; i < input.size(); i++) {
output(i) = CalculateBucketIndex(input(i), boundaries_vector);
}
}
private:
  int32 CalculateBucketIndex(const T value,
                             const std::vector<T>& boundaries_vector) {
auto first_bigger_it = std::upper_bound(boundaries_vector.begin(),
boundaries_vector.end(), value);
int32 index = first_bigger_it - boundaries_vector.begin();
CHECK(index >= 0 && index <= boundaries_vector.size())
<< "Invalid bucket index: " << index
<< " boundaries_vector.size(): " << boundaries_vector.size();
return index;
}
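  // Example (hypothetical boundaries): with boundaries_vector = {0, 1, 2},
  // std::upper_bound maps value -1 to index 0, value 0 to index 1, value 0.5
  // to index 1, and value 3 to index 3, i.e. n boundaries induce n + 1
  // buckets.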
};
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("BucketizeWithInputBoundaries") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
BucketizeWithInputBoundariesOp<T>);
REGISTER_KERNEL(int32);
REGISTER_KERNEL(int64);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL
} // namespace tensorflow

File diff suppressed because it is too large

View File

@ -1,782 +0,0 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <algorithm>
#include <iterator>
#include <map>
#include <string>
#include <vector>
#include "tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h"
#include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h"
#include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
namespace boosted_trees {
namespace {
const char* const kStampTokenName = "stamp_token";
const char* const kNextStampTokenName = "next_stamp_token";
struct PartitionKey {
PartitionKey() : partition_id(-1), feature_id(-1), dimension(-1) {}
PartitionKey(int32 p, int64 f, int32 d)
: partition_id(p), feature_id(f), dimension(d) {}
bool operator==(const PartitionKey& other) const {
return (partition_id == other.partition_id) &&
(dimension == other.dimension) && (feature_id == other.feature_id);
}
// Comparator for PartitionKey.
struct Less {
bool operator()(const PartitionKey& a, const PartitionKey& b) const {
if (a.partition_id < b.partition_id) {
return true;
}
if ((a.partition_id == b.partition_id) && (a.dimension < b.dimension)) {
return true;
}
if ((a.partition_id == b.partition_id) && (a.dimension == b.dimension) &&
(a.feature_id < b.feature_id)) {
return true;
}
return false;
}
};
// Tree partition defined by traversing the tree to the leaf.
int32 partition_id;
// Feature column id.
int64 feature_id;
// Dimension within feature column.
int32 dimension;
};
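// The comparator above is a plain lexicographic order on
// (partition_id, dimension, feature_id). An equivalent, more compact
// formulation for reference; LessViaChainedCompare is an illustrative name
// local to this sketch and is not used by the code above.
inline bool LessViaChainedCompare(const PartitionKey& a,
                                  const PartitionKey& b) {
  if (a.partition_id != b.partition_id) {
    return a.partition_id < b.partition_id;
  }
  if (a.dimension != b.dimension) return a.dimension < b.dimension;
  return a.feature_id < b.feature_id;
}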
template <typename GradientType, typename HessianType>
class StatsAccumulatorResource : public boosted_trees::StampedResource {
using StatsByPartition =
std::map<PartitionKey, std::pair<GradientType, HessianType>,
PartitionKey::Less>;
public:
StatsAccumulatorResource(const TensorShape& gradient_shape,
const TensorShape& hessian_shape)
: gradient_shape_(gradient_shape),
hessian_shape_(hessian_shape),
num_updates_(0) {
// If GradientType/HessianType is scalar float then the shapes should be
// scalar and vice versa.
CHECK_EQ((std::is_same<GradientType, float>::value),
TensorShapeUtils::IsScalar(gradient_shape));
CHECK_EQ((std::is_same<HessianType, float>::value),
TensorShapeUtils::IsScalar(hessian_shape));
}
string DebugString() const override {
return strings::StrCat("StatsAccumulatorResource[size=", values_.size(),
"]");
}
void Clear() {
values_.clear();
num_updates_ = 0;
}
tensorflow::mutex* mutex() { return &mu_; }
StatsByPartition* mutable_values() { return &values_; }
const StatsByPartition& values() const { return values_; }
const int64& num_updates() const { return num_updates_; }
void set_num_updates(int64 val) { num_updates_ = val; }
const TensorShape& gradient_shape() const { return gradient_shape_; }
const TensorShape& hessian_shape() const { return hessian_shape_; }
private:
// Accumulated stats, keyed by (partition id, feature id, dimension).
StatsByPartition values_;
const TensorShape gradient_shape_;
const TensorShape hessian_shape_;
int64 num_updates_;
tensorflow::mutex mu_;
TF_DISALLOW_COPY_AND_ASSIGN(StatsAccumulatorResource);
};
using StatsAccumulatorScalarResource = StatsAccumulatorResource<float, float>;
using StatsAccumulatorTensorResource =
StatsAccumulatorResource<std::vector<float>, std::vector<float>>;
void SerializeScalarAccumulatorToOutput(
const StatsAccumulatorScalarResource& resource, OpKernelContext* context) {
int64 num_slots = resource.values().size();
Tensor* partition_ids_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
TensorShape({num_slots}),
&partition_ids_t));
auto partition_ids = partition_ids_t->vec<int32>();
// Feature ids tensor has ids of feature columns and their dimensions.
Tensor* feature_ids_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids",
TensorShape({num_slots, 2}),
&feature_ids_t));
auto feature_ids = feature_ids_t->matrix<int64>();
Tensor* gradients_t = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(
"output_gradients", TensorShape({num_slots}), &gradients_t));
auto gradients = gradients_t->vec<float>();
Tensor* hessians_t = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output("output_hessians",
TensorShape({num_slots}), &hessians_t));
auto hessians = hessians_t->vec<float>();
int i = 0;
for (const auto& iter : resource.values()) {
partition_ids(i) = iter.first.partition_id;
feature_ids(i, 0) = iter.first.feature_id;
feature_ids(i, 1) = iter.first.dimension;
gradients(i) = iter.second.first;
hessians(i) = iter.second.second;
++i;
}
}
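// The four outputs written above form parallel arrays with one row per
// accumulator slot; feature ids and dimensions share a two-column matrix.
// Viewed row-wise, slot i carries the following record. SerializedSlotView
// is an illustrative name local to this sketch:
struct SerializedSlotView {
  int32 partition_id;  // output_partition_ids(i)
  int64 feature_id;    // output_feature_ids(i, 0)
  int64 dimension;     // output_feature_ids(i, 1)
  float gradient;      // output_gradients(i)
  float hessian;       // output_hessians(i)
};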
void SerializeTensorAccumulatorToOutput(
const StatsAccumulatorTensorResource& resource, OpKernelContext* context) {
int64 num_slots = resource.values().size();
Tensor* partition_ids_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids",
TensorShape({num_slots}),
&partition_ids_t));
auto partition_ids = partition_ids_t->vec<int32>();
Tensor* feature_ids_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids",
TensorShape({num_slots, 2}),
&feature_ids_t));
auto feature_ids = feature_ids_t->matrix<int64>();
TensorShape gradient_shape = resource.gradient_shape();
int64 num_gradient_elements = gradient_shape.num_elements();
gradient_shape.InsertDim(0, num_slots);
Tensor* gradients_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("output_gradients", gradient_shape,
&gradients_t));
auto gradients = gradients_t->flat_outer_dims<float>();
TensorShape hessian_shape = resource.hessian_shape();
int64 num_hessian_elements = hessian_shape.num_elements();
hessian_shape.InsertDim(0, num_slots);
Tensor* hessians_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output("output_hessians",
hessian_shape, &hessians_t));
auto hessians = hessians_t->flat_outer_dims<float>();
int i = 0;
for (const auto& iter : resource.values()) {
partition_ids(i) = iter.first.partition_id;
feature_ids(i, 0) = iter.first.feature_id;
feature_ids(i, 1) = iter.first.dimension;
for (int j = 0; j < num_gradient_elements; ++j) {
gradients(i, j) = iter.second.first[j];
}
for (int j = 0; j < num_hessian_elements; ++j) {
hessians(i, j) = iter.second.second[j];
}
++i;
}
}
void AddToScalarAccumulator(
const core::RefCountPtr<StatsAccumulatorScalarResource>& resource,
const Tensor& partition_ids_t, const Tensor& feature_ids_t,
const Tensor& gradients_t, const Tensor& hessians_t) {
resource->set_num_updates(resource->num_updates() + 1);
const TensorShape& partition_ids_shape = partition_ids_t.shape();
const auto& partition_ids = partition_ids_t.vec<int32>();
const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
const auto& gradients = gradients_t.vec<float>();
const auto& hessians = hessians_t.vec<float>();
int64 num_updates = partition_ids_shape.dim_size(0);
auto stats_map = resource->mutable_values();
for (int64 i = 0; i < num_updates; ++i) {
const auto key =
PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
feature_ids_and_dimensions(i, 1));
auto itr = stats_map->find(key);
if (itr != stats_map->end()) {
itr->second.first += gradients(i);
itr->second.second += hessians(i);
} else {
(*stats_map)[key] = {gradients(i), hessians(i)};
}
}
}
void AddToScalarAccumulator(
const core::RefCountPtr<StatsAccumulatorScalarResource>& resource,
OpKernelContext* context) {
const Tensor* partition_ids_t;
OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
const Tensor* feature_ids_t;
OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
const Tensor* gradients_t;
OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
AddToScalarAccumulator(resource, *partition_ids_t, *feature_ids_t,
*gradients_t, *hessians_t);
}
void AddToTensorAccumulator(
const core::RefCountPtr<StatsAccumulatorTensorResource>& resource,
const Tensor& partition_ids_t, const Tensor& feature_ids_t,
const Tensor& gradients_t, const Tensor& hessians_t,
OpKernelContext* context) {
resource->set_num_updates(resource->num_updates() + 1);
const TensorShape& partition_ids_shape = partition_ids_t.shape();
const auto& partition_ids = partition_ids_t.vec<int32>();
const auto& feature_ids_and_dimensions = feature_ids_t.matrix<int64>();
TensorShape gradients_shape = gradients_t.shape();
const auto& gradients = gradients_t.flat_outer_dims<float>();
TensorShape hessians_shape = hessians_t.shape();
const auto& hessians = hessians_t.flat_outer_dims<float>();
gradients_shape.RemoveDim(0);
hessians_shape.RemoveDim(0);
// TODO(soroush): Move gradient and hessian shape check to ShapeFn.
OP_REQUIRES(
context, gradients_shape == resource->gradient_shape(),
errors::InvalidArgument(strings::StrCat(
"Gradients dimensions must match: ", gradients_shape.DebugString(),
", ", resource->gradient_shape().DebugString())));
OP_REQUIRES(
context, hessians_shape == resource->hessian_shape(),
errors::InvalidArgument(strings::StrCat(
"Hessian dimensions must match: ", hessians_shape.DebugString(), ", ",
resource->hessian_shape().DebugString())));
int64 num_updates = partition_ids_shape.dim_size(0);
auto stats_map = resource->mutable_values();
for (int64 i = 0; i < num_updates; ++i) {
const auto key =
PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0),
feature_ids_and_dimensions(i, 1));
auto itr = stats_map->find(key);
if (itr == stats_map->end()) {
std::vector<float> new_gradients(gradients_shape.num_elements());
for (int j = 0; j < gradients_shape.num_elements(); ++j) {
new_gradients[j] = gradients(i, j);
}
std::vector<float> new_hessians(hessians_shape.num_elements());
for (int j = 0; j < hessians_shape.num_elements(); ++j) {
new_hessians[j] = hessians(i, j);
}
(*stats_map)[key] = {new_gradients, new_hessians};
} else {
auto& stored_gradients = itr->second.first;
for (int j = 0; j < gradients_shape.num_elements(); ++j) {
stored_gradients[j] += gradients(i, j);
}
auto& stored_hessians = itr->second.second;
for (int j = 0; j < hessians_shape.num_elements(); ++j) {
stored_hessians[j] += hessians(i, j);
}
}
}
}
void AddToTensorAccumulator(
const core::RefCountPtr<StatsAccumulatorTensorResource>& resource,
OpKernelContext* context) {
const Tensor* partition_ids_t;
OP_REQUIRES_OK(context, context->input("partition_ids", &partition_ids_t));
const Tensor* feature_ids_t;
OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
const Tensor* gradients_t;
OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
AddToTensorAccumulator(resource, *partition_ids_t, *feature_ids_t,
*gradients_t, *hessians_t, context);
}
} // namespace
REGISTER_RESOURCE_HANDLE_KERNEL(StatsAccumulatorScalarResource);
REGISTER_RESOURCE_HANDLE_KERNEL(StatsAccumulatorTensorResource);
REGISTER_KERNEL_BUILDER(
Name("StatsAccumulatorScalarIsInitialized").Device(DEVICE_CPU),
IsResourceInitialized<StatsAccumulatorScalarResource>);
REGISTER_KERNEL_BUILDER(
Name("StatsAccumulatorTensorIsInitialized").Device(DEVICE_CPU),
IsResourceInitialized<StatsAccumulatorTensorResource>);
class CreateStatsAccumulatorScalarOp : public OpKernel {
public:
explicit CreateStatsAccumulatorScalarOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
TensorShape gradient_shape = TensorShape({});
TensorShape hessian_shape = TensorShape({});
auto* result =
new StatsAccumulatorScalarResource(gradient_shape, hessian_shape);
result->set_stamp(stamp_token_t->scalar<int64>()());
// Only create the resource if one does not already exist; report the
// status for any other error. If one already exists, the new resource is
// unreffed.
auto status = CreateResource(context, HandleFromInput(context, 0), result);
if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
OP_REQUIRES(context, false, status);
}
}
};
REGISTER_KERNEL_BUILDER(Name("CreateStatsAccumulatorScalar").Device(DEVICE_CPU),
CreateStatsAccumulatorScalarOp);
class CreateStatsAccumulatorTensorOp : public OpKernel {
public:
explicit CreateStatsAccumulatorTensorOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
const Tensor* gradient_shape_t;
OP_REQUIRES_OK(
context, context->input("per_slot_gradient_shape", &gradient_shape_t));
const Tensor* hessian_shape_t;
OP_REQUIRES_OK(context,
context->input("per_slot_hessian_shape", &hessian_shape_t));
TensorShape gradient_shape = TensorShape(gradient_shape_t->vec<int64>());
TensorShape hessian_shape = TensorShape(hessian_shape_t->vec<int64>());
auto* result =
new StatsAccumulatorTensorResource(gradient_shape, hessian_shape);
result->set_stamp(stamp_token_t->scalar<int64>()());
// Only create the resource if one does not already exist; report the
// status for any other error. If one already exists, the new resource is
// unreffed.
auto status = CreateResource(context, HandleFromInput(context, 0), result);
if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
OP_REQUIRES(context, false, status);
}
}
};
REGISTER_KERNEL_BUILDER(Name("CreateStatsAccumulatorTensor").Device(DEVICE_CPU),
CreateStatsAccumulatorTensorOp);
class StatsAccumulatorScalarAddOp : public OpKernel {
public:
explicit StatsAccumulatorScalarAddOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
OpInputList resource_handle_list;
OP_REQUIRES_OK(context, context->input_list("stats_accumulator_handles",
&resource_handle_list));
OpInputList partition_ids_list;
OP_REQUIRES_OK(context,
context->input_list("partition_ids", &partition_ids_list));
OpInputList feature_ids_list;
OP_REQUIRES_OK(context,
context->input_list("feature_ids", &feature_ids_list));
OpInputList gradients_list;
OP_REQUIRES_OK(context, context->input_list("gradients", &gradients_list));
OpInputList hessians_list;
OP_REQUIRES_OK(context, context->input_list("hessians", &hessians_list));
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
thread::ThreadPool* const worker_threads =
context->device()->tensorflow_cpu_worker_threads()->workers;
boosted_trees::utils::ParallelFor(
resource_handle_list.size(), worker_threads->NumThreads(),
worker_threads,
[&context, &resource_handle_list, &partition_ids_list,
&feature_ids_list, &gradients_list, &hessians_list,
stamp_token](int64 start, int64 end) {
for (int resource_handle_idx = start; resource_handle_idx < end;
++resource_handle_idx) {
const ResourceHandle& handle =
resource_handle_list[resource_handle_idx]
.flat<ResourceHandle>()(0);
core::RefCountPtr<StatsAccumulatorScalarResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, handle, &resource));
mutex_lock l(*resource->mutex());
// If the stamp is invalid we drop the update.
if (!resource->is_stamp_valid(stamp_token)) {
VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
<< "Passed stamp token: " << stamp_token << " "
<< "Current token: " << resource->stamp();
return;
}
AddToScalarAccumulator(resource,
partition_ids_list[resource_handle_idx],
feature_ids_list[resource_handle_idx],
gradients_list[resource_handle_idx],
hessians_list[resource_handle_idx]);
}
});
}
};
REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorScalarAdd").Device(DEVICE_CPU),
StatsAccumulatorScalarAddOp);
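// The staleness rule used by the Add ops, stated on its own: an update whose
// stamp token differs from the resource's current stamp is dropped rather
// than applied, so late writes from a previous round cannot corrupt the
// current round. ShouldApplyUpdate is an illustrative name local to this
// sketch and mirrors the is_stamp_valid() check above.
inline bool ShouldApplyUpdate(int64 resource_stamp, int64 update_stamp) {
  return resource_stamp == update_stamp;
}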
class StatsAccumulatorTensorAddOp : public OpKernel {
public:
explicit StatsAccumulatorTensorAddOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
OpInputList resource_handle_list;
OP_REQUIRES_OK(context, context->input_list("stats_accumulator_handles",
&resource_handle_list));
OpInputList partition_ids_list;
OP_REQUIRES_OK(context,
context->input_list("partition_ids", &partition_ids_list));
OpInputList feature_ids_list;
OP_REQUIRES_OK(context,
context->input_list("feature_ids", &feature_ids_list));
OpInputList gradients_list;
OP_REQUIRES_OK(context, context->input_list("gradients", &gradients_list));
OpInputList hessians_list;
OP_REQUIRES_OK(context, context->input_list("hessians", &hessians_list));
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
thread::ThreadPool* const worker_threads =
context->device()->tensorflow_cpu_worker_threads()->workers;
boosted_trees::utils::ParallelFor(
resource_handle_list.size(), worker_threads->NumThreads(),
worker_threads,
[&context, &resource_handle_list, &partition_ids_list,
&feature_ids_list, &gradients_list, &hessians_list,
stamp_token](int64 start, int64 end) {
for (int resource_handle_idx = start; resource_handle_idx < end;
++resource_handle_idx) {
const ResourceHandle& handle =
resource_handle_list[resource_handle_idx]
.flat<ResourceHandle>()(0);
core::RefCountPtr<StatsAccumulatorTensorResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, handle, &resource));
mutex_lock l(*resource->mutex());
// If the stamp is invalid we drop the update.
if (!resource->is_stamp_valid(stamp_token)) {
VLOG(1) << "Invalid stamp token in StatsAccumulatorScalarAddOp. "
<< "Passed stamp token: " << stamp_token << " "
<< "Current token: " << resource->stamp();
return;
}
AddToTensorAccumulator(resource,
partition_ids_list[resource_handle_idx],
feature_ids_list[resource_handle_idx],
gradients_list[resource_handle_idx],
hessians_list[resource_handle_idx], context);
}
});
}
};
REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorTensorAdd").Device(DEVICE_CPU),
StatsAccumulatorTensorAddOp);
class StatsAccumulatorScalarFlushOp : public OpKernel {
public:
explicit StatsAccumulatorScalarFlushOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
core::RefCountPtr<StatsAccumulatorScalarResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource));
mutex_lock l(*resource->mutex());
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
// If the stamp is invalid we crash, restarting the PS. This shouldn't
// happen, since only the chief calls this function and the chief is
// guaranteed to be in a consistent state.
CHECK(resource->is_stamp_valid(stamp_token));
const Tensor* next_stamp_token_t;
OP_REQUIRES_OK(context,
context->input(kNextStampTokenName, &next_stamp_token_t));
int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
CHECK(stamp_token != next_stamp_token);
SerializeScalarAccumulatorToOutput(*resource, context);
Tensor* num_updates_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("num_updates", TensorShape({}),
&num_updates_t));
num_updates_t->scalar<int64>()() = resource->num_updates();
resource->Clear();
resource->set_stamp(next_stamp_token);
}
};
REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorScalarFlush").Device(DEVICE_CPU),
StatsAccumulatorScalarFlushOp);
class StatsAccumulatorTensorFlushOp : public OpKernel {
public:
explicit StatsAccumulatorTensorFlushOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
core::RefCountPtr<StatsAccumulatorTensorResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource));
mutex_lock l(*resource->mutex());
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
const Tensor* next_stamp_token_t;
OP_REQUIRES_OK(context,
context->input(kNextStampTokenName, &next_stamp_token_t));
int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
// If the stamp is invalid we crash, restarting the PS. This shouldn't
// happen, since only the chief calls this function and the chief is
// guaranteed to be in a consistent state.
CHECK(resource->is_stamp_valid(stamp_token));
CHECK(stamp_token != next_stamp_token);
SerializeTensorAccumulatorToOutput(*resource, context);
Tensor* num_updates_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("num_updates", TensorShape({}),
&num_updates_t));
num_updates_t->scalar<int64>()() = resource->num_updates();
resource->Clear();
resource->set_stamp(next_stamp_token);
}
};
REGISTER_KERNEL_BUILDER(Name("StatsAccumulatorTensorFlush").Device(DEVICE_CPU),
StatsAccumulatorTensorFlushOp);
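// The flush protocol shared by both Flush ops, reduced to its essentials.
// Everything happens under the resource mutex: verify the caller's stamp,
// read out the accumulated state, then clear and advance the stamp so adds
// still carrying the old token are dropped. MiniAccumulator and
// FlushAndRestamp are illustrative names local to this sketch.
struct MiniAccumulator {
  int64 stamp = -1;
  int64 num_updates = 0;
};
inline int64 FlushAndRestamp(MiniAccumulator* accumulator, int64 stamp_token,
                             int64 next_stamp_token) {
  CHECK_EQ(accumulator->stamp, stamp_token);  // Only the chief flushes.
  CHECK_NE(stamp_token, next_stamp_token);
  const int64 num_updates = accumulator->num_updates;
  accumulator->num_updates = 0;             // Clear().
  accumulator->stamp = next_stamp_token;    // set_stamp().
  return num_updates;
}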
class StatsAccumulatorScalarDeserializeOp : public OpKernel {
public:
explicit StatsAccumulatorScalarDeserializeOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
core::RefCountPtr<StatsAccumulatorScalarResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource));
mutex_lock l(*resource->mutex());
// Check the stamp token.
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
resource->Clear();
resource->set_stamp(stamp_token);
AddToScalarAccumulator(resource, context);
const Tensor* num_updates_t;
OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
resource->set_num_updates(num_updates_t->scalar<int64>()());
}
};
REGISTER_KERNEL_BUILDER(
Name("StatsAccumulatorScalarDeserialize").Device(DEVICE_CPU),
StatsAccumulatorScalarDeserializeOp);
class StatsAccumulatorTensorDeserializeOp : public OpKernel {
public:
explicit StatsAccumulatorTensorDeserializeOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
core::RefCountPtr<StatsAccumulatorTensorResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource));
mutex_lock l(*resource->mutex());
// Check the stamp token.
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
resource->Clear();
resource->set_stamp(stamp_token);
AddToTensorAccumulator(resource, context);
const Tensor* num_updates_t;
OP_REQUIRES_OK(context, context->input("num_updates", &num_updates_t));
resource->set_num_updates(num_updates_t->scalar<int64>()());
}
};
REGISTER_KERNEL_BUILDER(
Name("StatsAccumulatorTensorDeserialize").Device(DEVICE_CPU),
StatsAccumulatorTensorDeserializeOp);
class StatsAccumulatorScalarSerializeOp : public OpKernel {
public:
explicit StatsAccumulatorScalarSerializeOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
core::RefCountPtr<StatsAccumulatorScalarResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource));
mutex_lock l(*resource->mutex());
SerializeScalarAccumulatorToOutput(*resource, context);
Tensor* stamp_token_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("stamp_token", TensorShape({}),
&stamp_token_t));
stamp_token_t->scalar<int64>()() = resource->stamp();
Tensor* num_updates_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("num_updates", TensorShape({}),
&num_updates_t));
num_updates_t->scalar<int64>()() = resource->num_updates();
}
};
REGISTER_KERNEL_BUILDER(
Name("StatsAccumulatorScalarSerialize").Device(DEVICE_CPU),
StatsAccumulatorScalarSerializeOp);
class StatsAccumulatorTensorSerializeOp : public OpKernel {
public:
explicit StatsAccumulatorTensorSerializeOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
core::RefCountPtr<StatsAccumulatorTensorResource> resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&resource));
mutex_lock l(*resource->mutex());
SerializeTensorAccumulatorToOutput(*resource, context);
Tensor* stamp_token_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("stamp_token", TensorShape({}),
&stamp_token_t));
stamp_token_t->scalar<int64>()() = resource->stamp();
Tensor* num_updates_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("num_updates", TensorShape({}),
&num_updates_t));
num_updates_t->scalar<int64>()() = resource->num_updates();
}
};
REGISTER_KERNEL_BUILDER(
Name("StatsAccumulatorTensorSerialize").Device(DEVICE_CPU),
StatsAccumulatorTensorSerializeOp);
class StatsAccumulatorScalarMakeSummaryOp : public OpKernel {
public:
explicit StatsAccumulatorScalarMakeSummaryOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
TensorShape gradient_shape = TensorShape({});
TensorShape hessian_shape = TensorShape({});
core::RefCountPtr<StatsAccumulatorScalarResource> resource(
new StatsAccumulatorScalarResource(gradient_shape, hessian_shape));
// Accumulate the input stats into the temporary resource.
AddToScalarAccumulator(resource, context);
SerializeScalarAccumulatorToOutput(*resource, context);
}
};
REGISTER_KERNEL_BUILDER(
Name("StatsAccumulatorScalarMakeSummary").Device(DEVICE_CPU),
StatsAccumulatorScalarMakeSummaryOp);
class StatsAccumulatorTensorMakeSummaryOp : public OpKernel {
public:
explicit StatsAccumulatorTensorMakeSummaryOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor* gradients_t;
OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
TensorShape gradients_shape = gradients_t->shape();
gradients_shape.RemoveDim(0);
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
TensorShape hessians_shape = hessians_t->shape();
hessians_shape.RemoveDim(0);
core::RefCountPtr<StatsAccumulatorTensorResource> resource(
new StatsAccumulatorTensorResource(gradients_shape, hessians_shape));
// Accumulate the input stats into the temporary resource.
AddToTensorAccumulator(resource, context);
SerializeTensorAccumulatorToOutput(*resource, context);
}
};
REGISTER_KERNEL_BUILDER(
Name("StatsAccumulatorTensorMakeSummary").Device(DEVICE_CPU),
StatsAccumulatorTensorMakeSummaryOp);
} // namespace boosted_trees
} // namespace tensorflow

View File

@ -1,996 +0,0 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <vector>
#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"
#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace tensorflow {
using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig;
namespace boosted_trees {
namespace {
using boosted_trees::learner::LearnerConfig;
using boosted_trees::learner::LearningRateConfig;
using boosted_trees::models::DecisionTreeEnsembleResource;
using boosted_trees::trees::DecisionTreeConfig;
using boosted_trees::trees::Leaf;
using boosted_trees::trees::TreeNode;
using boosted_trees::trees::TreeNodeMetadata;
using boosted_trees::utils::DropoutUtils;
// SplitCandidate holds the split candidate node along with the stats.
struct SplitCandidate {
// Id of handler that generated the split candidate.
int64 handler_id;
// Split gain.
float gain;
// Split info.
learner::SplitInfo split_info;
// Oblivious split info.
learner::ObliviousSplitInfo oblivious_split_info;
};
// Checks that the leaf is not empty.
bool IsLeafWellFormed(const Leaf& leaf) {
return leaf.has_sparse_vector() || leaf.has_vector();
}
// Helper method to update the best split per partition given
// a current candidate.
void UpdateBestSplit(
const boosted_trees::learner::LearnerConfig& learner_config,
int32 partition_id, SplitCandidate* split,
std::map<int32, SplitCandidate>* best_splits) {
// Don't consider nodeless splits.
if (TF_PREDICT_FALSE(split->split_info.split_node().node_case() ==
TreeNode::NODE_NOT_SET)) {
return;
}
// Don't consider negative splits if we're pre-pruning the tree.
// Note that zero-gain splits are acceptable, as they do roughly as well as
// bias centering in that partition would.
if (learner_config.pruning_mode() ==
boosted_trees::learner::LearnerConfig::PRE_PRUNE &&
split->gain < 0) {
return;
}
// If the current node is pure, one of the leaves will be empty, so the
// split is meaningless and we should not split.
if (!(IsLeafWellFormed(split->split_info.right_child()) &&
IsLeafWellFormed(split->split_info.left_child()))) {
VLOG(1) << "Split does not actually split anything";
return;
}
// Take the split if we don't have a candidate yet.
auto best_split_it = best_splits->find(partition_id);
if (best_split_it == best_splits->end()) {
best_splits->insert(std::make_pair(partition_id, std::move(*split)));
return;
}
// Determine if best split so far needs to be replaced.
SplitCandidate& best_split = best_split_it->second;
if (TF_PREDICT_FALSE(split->gain == best_split.gain)) {
// Tie break on node case preferring simpler tree node types.
VLOG(2) << "Attempting to tie break with smaller node case. "
<< "(current split: " << split->split_info.split_node().node_case()
<< ", best split: "
<< best_split.split_info.split_node().node_case() << ")";
if (split->split_info.split_node().node_case() <
best_split.split_info.split_node().node_case()) {
best_split = std::move(*split);
} else if (split->split_info.split_node().node_case() ==
best_split.split_info.split_node().node_case()) {
// Tie break on handler Id.
VLOG(2) << "Tie breaking with higher handler Id. "
<< "(current split: " << split->handler_id
<< ", best split: " << best_split.handler_id << ")";
if (split->handler_id > best_split.handler_id) {
best_split = std::move(*split);
}
}
} else if (split->gain > best_split.gain) {
best_split = std::move(*split);
}
}
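// The replacement rule above is a total preference order: a candidate
// replaces the incumbent iff it wins on higher gain, then on smaller node
// case, then on higher handler id. Stated directly for reference;
// ShouldReplaceBestSplit is an illustrative name local to this sketch and
// covers only the comparison, not the early-exit validity checks above.
inline bool ShouldReplaceBestSplit(const SplitCandidate& candidate,
                                   const SplitCandidate& best) {
  const auto candidate_case = candidate.split_info.split_node().node_case();
  const auto best_case = best.split_info.split_node().node_case();
  if (candidate.gain != best.gain) return candidate.gain > best.gain;
  if (candidate_case != best_case) return candidate_case < best_case;
  return candidate.handler_id > best.handler_id;
}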
// Helper method to check whether a node is terminal, i.e. whether it only
// has leaf nodes as children.
bool IsTerminalSplitNode(const size_t node_id,
const std::vector<int32>& children,
const std::vector<TreeNode>& nodes) {
for (const int32 child_id : children) {
const auto& child_node = nodes[child_id];
CHECK(child_node.node_case() != TreeNode::NODE_NOT_SET);
if (child_node.node_case() != TreeNode::kLeaf) {
return false;
}
}
return true;
}
// Helper method to recursively prune the tree in a depth-first fashion.
void RecursivePruneTree(const size_t node_id, std::vector<TreeNode>* nodes) {
// Base case when we reach a leaf.
TreeNode& tree_node = (*nodes)[node_id];
CHECK(tree_node.node_case() != TreeNode::NODE_NOT_SET);
if (tree_node.node_case() == TreeNode::kLeaf) {
return;
}
// Traverse node children first and recursively prune their sub-trees.
const std::vector<int32> children =
boosted_trees::trees::DecisionTree::GetChildren(tree_node);
for (const int32 child_id : children) {
RecursivePruneTree(child_id, nodes);
}
// Two conditions must be satisfied to prune the node:
// 1- The split gain is negative.
// 2- After depth-first pruning, the node only has leaf children.
TreeNodeMetadata* node_metadata = tree_node.mutable_node_metadata();
if (node_metadata->gain() < 0 &&
IsTerminalSplitNode(node_id, children, (*nodes))) {
// Clear node children.
for (const int32 child_id : children) {
auto& child_node = (*nodes)[child_id];
child_node.Clear();
}
// Change node back into leaf.
(*tree_node.mutable_leaf()) = *node_metadata->mutable_original_leaf();
// Clear gain for leaf node.
tree_node.clear_node_metadata();
} else {
// Clear original leaf as it's no longer needed for back-track pruning.
node_metadata->clear_original_leaf();
}
}
} // namespace
class CenterTreeEnsembleBiasOp : public OpKernel {
public:
explicit CenterTreeEnsembleBiasOp(OpKernelConstruction* const context)
: OpKernel(context) {
// Read learner config.
string serialized_learner_config;
OP_REQUIRES_OK(context, context->GetAttr("learner_config",
&serialized_learner_config));
OP_REQUIRES(context,
learner_config_.ParseFromString(serialized_learner_config),
errors::InvalidArgument("Unable to parse learner config."));
// Read centering epsilon.
OP_REQUIRES_OK(context,
context->GetAttr("centering_epsilon", &centering_epsilon_));
}
void Compute(OpKernelContext* const context) override {
// Get decision tree ensemble.
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
mutex_lock l(*ensemble_resource->get_mutex());
// Get the stamp token.
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
// Only the chief should run this op, and it is guaranteed to be in a
// consistent state, so the stamps must always match.
CHECK(ensemble_resource->is_stamp_valid(stamp_token));
// Get the next stamp token.
const Tensor* next_stamp_token_t;
OP_REQUIRES_OK(context,
context->input("next_stamp_token", &next_stamp_token_t));
int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
CHECK(stamp_token != next_stamp_token);
// Update the ensemble stamp.
ensemble_resource->set_stamp(next_stamp_token);
// Get the delta updates.
const Tensor* delta_updates_t;
OP_REQUIRES_OK(context, context->input("delta_updates", &delta_updates_t));
auto delta_updates = delta_updates_t->vec<float>();
const int64 logits_dimension = delta_updates_t->dim_size(0);
// Get the bias.
boosted_trees::trees::Leaf* const bias =
RetrieveBias(ensemble_resource, logits_dimension);
CHECK(bias->has_vector());
// Update the bias.
float total_delta = 0;
auto* bias_vec = bias->mutable_vector();
for (size_t idx = 0; idx < bias->vector().value_size(); ++idx) {
float delta = delta_updates(idx);
bias_vec->set_value(idx, bias_vec->value(idx) + delta);
total_delta += std::abs(delta);
}
// Make a centering continuation decision based on current update.
bool continue_centering = total_delta > centering_epsilon_;
if (continue_centering) {
VLOG(1) << "Continuing to center bias, delta=" << total_delta;
} else {
VLOG(1) << "Done centering bias, delta=" << total_delta;
ensemble_resource->LastTreeMetadata()->set_is_finalized(true);
}
Tensor* continue_centering_t = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output("continue_centering", TensorShape({}),
&continue_centering_t));
continue_centering_t->scalar<bool>()() = continue_centering;
}
private:
// Helper method to retrieve the bias from the tree ensemble.
Leaf* RetrieveBias(
const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
int64 logits_dimension) {
const int32 num_trees = ensemble_resource->num_trees();
if (num_trees <= 0) {
// Add a new bias leaf.
ensemble_resource->IncrementAttempts();
boosted_trees::trees::DecisionTreeConfig* const tree_config =
ensemble_resource->AddNewTree(1.0);
auto* const leaf = tree_config->add_nodes()->mutable_leaf();
for (size_t idx = 0; idx < logits_dimension; ++idx) {
leaf->mutable_vector()->add_value(0.0);
}
return leaf;
} else if (num_trees == 1) {
// Confirms that the only tree is a bias and returns its leaf.
boosted_trees::trees::DecisionTreeConfig* const tree_config =
ensemble_resource->LastTree();
CHECK_EQ(tree_config->nodes_size(), 1);
CHECK_EQ(tree_config->nodes(0).node_case(), TreeNode::kLeaf);
return tree_config->mutable_nodes(0)->mutable_leaf();
} else {
LOG(FATAL) << "Unable to center bias on an already grown ensemble";
}
}
boosted_trees::learner::LearnerConfig learner_config_;
float centering_epsilon_;
};
REGISTER_KERNEL_BUILDER(Name("CenterTreeEnsembleBias").Device(DEVICE_CPU),
CenterTreeEnsembleBiasOp);
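// The stopping rule used by CenterTreeEnsembleBiasOp, stated on its own:
// centering continues while the L1 norm of the bias update exceeds
// centering_epsilon. ContinueCenteringSketch is an illustrative name local
// to this sketch and is not used by the op above.
inline bool ContinueCenteringSketch(const std::vector<float>& deltas,
                                    float centering_epsilon) {
  float total_delta = 0;
  for (const float delta : deltas) total_delta += std::abs(delta);
  return total_delta > centering_epsilon;
}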
class GrowTreeEnsembleOp : public OpKernel {
public:
explicit GrowTreeEnsembleOp(OpKernelConstruction* const context)
: OpKernel(context) {
// Read the number of handlers. Note that this is the static number of all
// handlers; any subset of them may be active at a given time.
OP_REQUIRES_OK(context, context->GetAttr("num_handlers", &num_handlers_));
OP_REQUIRES_OK(context, context->GetAttr("center_bias", &center_bias_));
// Read learner config.
string serialized_learner_config;
OP_REQUIRES_OK(context, context->GetAttr("learner_config",
&serialized_learner_config));
OP_REQUIRES(context,
learner_config_.ParseFromString(serialized_learner_config),
errors::InvalidArgument("Unable to parse learner config."));
// Determine whether dropout was used when building this tree.
if (learner_config_.has_learning_rate_tuner() &&
learner_config_.learning_rate_tuner().tuner_case() ==
LearningRateConfig::kDropout) {
dropout_config_ = learner_config_.learning_rate_tuner().dropout();
dropout_was_applied_ = true;
} else {
dropout_was_applied_ = false;
}
}
void Compute(OpKernelContext* const context) override {
// Get decision tree ensemble.
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
mutex_lock l(*ensemble_resource->get_mutex());
// Get the stamp token.
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
// Only the chief should run this op, and it is guaranteed to be in a
// consistent state, so the stamps must always match.
CHECK(ensemble_resource->is_stamp_valid(stamp_token));
// Get the next stamp token.
const Tensor* next_stamp_token_t;
OP_REQUIRES_OK(context,
context->input("next_stamp_token", &next_stamp_token_t));
int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
CHECK(stamp_token != next_stamp_token);
// Update the ensemble stamp regardless of whether a layer
// or tree is actually grown.
ensemble_resource->set_stamp(next_stamp_token);
// Read the learning_rate.
const Tensor* learning_rate_t;
OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t));
float learning_rate = learning_rate_t->scalar<float>()();
// Read the weak learner type to use.
const Tensor* weak_learner_type_t;
OP_REQUIRES_OK(context,
context->input("weak_learner_type", &weak_learner_type_t));
const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
const Tensor* seed_t;
OP_REQUIRES_OK(context, context->input("dropout_seed", &seed_t));
// Cast seed to uint64.
const uint64 dropout_seed = seed_t->scalar<int64>()();
// Read partition Ids, gains and split candidates.
OpInputList partition_ids_list;
OpInputList gains_list;
OpInputList splits_list;
OP_REQUIRES_OK(context,
context->input_list("partition_ids", &partition_ids_list));
OP_REQUIRES_OK(context, context->input_list("gains", &gains_list));
OP_REQUIRES_OK(context, context->input_list("splits", &splits_list));
// Increment attempt stats.
ensemble_resource->IncrementAttempts();
// Find best splits for each active partition.
std::map<int32, SplitCandidate> best_splits;
switch (weak_learner_type) {
case LearnerConfig::NORMAL_DECISION_TREE: {
FindBestSplitsPerPartitionNormal(context, partition_ids_list,
gains_list, splits_list, &best_splits);
break;
}
case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
FindBestSplitOblivious(context, gains_list, splits_list, &best_splits);
break;
}
}
// No-op if no new splits can be considered.
if (best_splits.empty()) {
LOG(WARNING) << "Not growing tree ensemble as no good splits were found.";
return;
}
// Get the max tree depth.
const Tensor* max_tree_depth_t;
OP_REQUIRES_OK(context,
context->input("max_tree_depth", &max_tree_depth_t));
const int32 max_tree_depth = max_tree_depth_t->scalar<int32>()();
// Update and retrieve the growable tree.
// If the tree is fully built and dropout was applied, this also adjusts
// the weights of the dropped trees and of the last tree.
DecisionTreeConfig* const tree_config = UpdateAndRetrieveGrowableTree(
ensemble_resource, learning_rate, dropout_seed, max_tree_depth,
weak_learner_type);
// Split tree nodes.
switch (weak_learner_type) {
case LearnerConfig::NORMAL_DECISION_TREE: {
for (auto& split_entry : best_splits) {
SplitTreeNode(split_entry.first, &split_entry.second, tree_config,
ensemble_resource);
}
break;
}
case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
SplitTreeLayer(&best_splits[0], tree_config, ensemble_resource);
break;
}
}
// Post-prune finalized tree if needed.
if (learner_config_.pruning_mode() ==
boosted_trees::learner::LearnerConfig::POST_PRUNE &&
ensemble_resource->LastTreeMetadata()->is_finalized()) {
VLOG(2) << "Post-pruning finalized tree.";
if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE) {
LOG(FATAL) << "Post-prunning is not implemented for Oblivious trees.";
}
PruneTree(tree_config);
// If after post-pruning the whole tree has no gain, remove the tree
// altogether from the ensemble.
if (tree_config->nodes_size() <= 0) {
ensemble_resource->RemoveLastTree();
}
if ((ensemble_resource->num_trees() == 0 ||
ensemble_resource->LastTreeMetadata()->is_finalized()) &&
learner_config_.has_each_tree_start() &&
learner_config_.each_tree_start().nodes_size() > 0) {
DCHECK_GT(learner_config_.each_tree_start_num_layers(), 0);
// Add a new dummy tree.
boosted_trees::trees::DecisionTreeConfig* const tree_config =
ensemble_resource->AddNewTree(learning_rate);
VLOG(1) << "Adding a new forced tree";
*tree_config = learner_config_.each_tree_start();
boosted_trees::trees::DecisionTreeMetadata* const tree_metadata =
ensemble_resource->LastTreeMetadata();
tree_metadata->set_is_finalized(max_tree_depth <= 1);
tree_metadata->set_num_tree_weight_updates(1);
tree_metadata->set_num_layers_grown(
learner_config_.each_tree_start_num_layers());
}
}
}
private:
// Helper method which effectively does a reduce over all split candidates
// and finds the best split for each partition.
void FindBestSplitsPerPartitionNormal(
OpKernelContext* const context, const OpInputList& partition_ids_list,
const OpInputList& gains_list, const OpInputList& splits_list,
std::map<int32, SplitCandidate>* best_splits) {
// Find best split per partition going through every feature candidate.
// TODO(salehay): Is this worth parallelizing?
for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) {
const auto& partition_ids = partition_ids_list[handler_id].vec<int32>();
const auto& gains = gains_list[handler_id].vec<float>();
const auto& splits = splits_list[handler_id].vec<tstring>();
OP_REQUIRES(context, partition_ids.size() == gains.size(),
errors::InvalidArgument(
"Inconsistent partition Ids and gains tensors: ",
partition_ids.size(), " != ", gains.size()));
OP_REQUIRES(context, partition_ids.size() == splits.size(),
errors::InvalidArgument(
"Inconsistent partition Ids and splits tensors: ",
partition_ids.size(), " != ", splits.size()));
for (size_t candidate_idx = 0; candidate_idx < splits.size();
++candidate_idx) {
// Get current split candidate.
const auto& partition_id = partition_ids(candidate_idx);
const auto& gain = gains(candidate_idx);
const auto& serialized_split = splits(candidate_idx);
SplitCandidate split;
split.handler_id = handler_id;
split.gain = gain;
OP_REQUIRES(context, split.split_info.ParseFromString(serialized_split),
errors::InvalidArgument("Unable to parse split info."));
// Update best split for partition based on the current candidate.
UpdateBestSplit(learner_config_, partition_id, &split, best_splits);
}
}
}
void FindBestSplitOblivious(OpKernelContext* const context,
const OpInputList& gains_list,
const OpInputList& splits_list,
std::map<int32, SplitCandidate>* best_splits) {
// Find best split per partition going through every feature candidate.
for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) {
const auto& gains = gains_list[handler_id].vec<float>();
const auto& splits = splits_list[handler_id].vec<tstring>();
OP_REQUIRES(context, gains.size() == 1,
errors::InvalidArgument(
"Gains size must be one for oblivious weak learner: ",
gains.size(), " != ", 1));
OP_REQUIRES(context, splits.size() == 1,
errors::InvalidArgument(
"Splits size must be one for oblivious weak learner: ",
splits.size(), " != ", 1));
// Get current split candidate.
const auto& gain = gains(0);
const auto& serialized_split = splits(0);
SplitCandidate split;
split.handler_id = handler_id;
split.gain = gain;
OP_REQUIRES(
context, split.oblivious_split_info.ParseFromString(serialized_split),
errors::InvalidArgument("Unable to parse oblivious split info."));
auto split_info = split.oblivious_split_info;
CHECK(split_info.children_size() % 2 == 0)
<< "The oblivious split should generate an even number of children: "
<< split_info.children_size();
// If every node is pure, then we shouldn't split.
bool only_pure_nodes = true;
for (int idx = 0; idx < split_info.children_size(); idx += 2) {
if (IsLeafWellFormed(*split_info.mutable_children(idx)) &&
IsLeafWellFormed(*split_info.mutable_children(idx + 1))) {
only_pure_nodes = false;
break;
}
}
if (only_pure_nodes) {
VLOG(1) << "The oblivious split does not actually split anything.";
continue;
}
// Don't consider negative splits if we're pre-pruning the tree.
if (learner_config_.pruning_mode() == learner::LearnerConfig::PRE_PRUNE &&
gain < 0) {
continue;
}
// Take the split if we don't have a candidate yet.
auto best_split_it = best_splits->find(0);
if (best_split_it == best_splits->end()) {
best_splits->insert(std::make_pair(0, std::move(split)));
continue;
}
// Determine if we should update best split.
SplitCandidate& best_split = best_split_it->second;
trees::TreeNode current_node = split_info.split_node();
trees::TreeNode best_node = best_split.oblivious_split_info.split_node();
if (TF_PREDICT_FALSE(gain == best_split.gain)) {
// Tie break on node case preferring simpler tree node types.
VLOG(2) << "Attempting to tie break with smaller node case. "
<< "(current split: " << current_node.node_case()
<< ", best split: " << best_node.node_case() << ")";
if (current_node.node_case() < best_node.node_case()) {
best_split = std::move(split);
} else if (current_node.node_case() == best_node.node_case()) {
// Tie break on handler Id.
VLOG(2) << "Tie breaking with higher handler Id. "
<< "(current split: " << handler_id
<< ", best split: " << best_split.handler_id << ")";
if (handler_id > best_split.handler_id) {
best_split = std::move(split);
}
}
} else if (gain > best_split.gain) {
best_split = std::move(split);
}
}
}
void UpdateTreeWeightsIfDropout(
const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
const uint64 dropout_seed) {
// It is possible that the tree was built with dropout. If so, we need to
// adjust the tree weights; otherwise there is nothing to do.
if (!dropout_was_applied_ ||
!ensemble_resource->LastTreeMetadata()->is_finalized()) {
return;
}
const int32 num_trees = ensemble_resource->num_trees();
// Based on the seed, figure out which trees were dropped before.
std::unordered_set<int32> trees_not_to_drop;
if (center_bias_) {
trees_not_to_drop.insert(0);
}
// Last tree is the current tree that is built.
const int32 current_tree = num_trees - 1;
trees_not_to_drop.insert(current_tree);
// Since only the chief builds the trees, we are sure that the other tree
// weights didn't change.
std::vector<float> weights = ensemble_resource->GetTreeWeights();
std::vector<int32> dropped_trees;
std::vector<float> dropped_trees_weights;
const auto dropout_status = DropoutUtils::DropOutTrees(
dropout_seed, dropout_config_, trees_not_to_drop, weights,
&dropped_trees, &dropped_trees_weights);
CHECK(dropout_status.ok())
<< "Can't figure out what trees were dropped out before, error is "
<< dropout_status.error_message();
// Now that we have the dropped trees, update their weights and the weight
// of the current tree.
if (!dropped_trees.empty()) {
std::vector<int32> increment_num_updates(num_trees, 0);
DropoutUtils::GetTreesWeightsForAddingTrees(
dropped_trees, dropped_trees_weights, current_tree,
1 /* only 1 tree was added */, &weights, &increment_num_updates);
// Update the weights and num of updates for trees.
for (int i = 0; i < num_trees; ++i) {
ensemble_resource->SetTreeWeight(i, weights[i],
increment_num_updates[i]);
}
}
}
// Helper method to update the growable tree which is by definition the last
// tree in the ensemble.
DecisionTreeConfig* UpdateAndRetrieveGrowableTree(
const core::RefCountPtr<DecisionTreeEnsembleResource>& ensemble_resource,
const float learning_rate, const uint64 dropout_seed,
const int32 max_tree_depth, const int32 weak_learner_type) {
const auto num_trees = ensemble_resource->num_trees();
if (num_trees <= 0 ||
ensemble_resource->LastTreeMetadata()->is_finalized()) {
// Create a new tree with a no-op leaf.
boosted_trees::trees::DecisionTreeConfig* const tree_config =
ensemble_resource->AddNewTree(learning_rate);
VLOG(1) << "Adding layer #0 to tree #" << num_trees << " of ensemble of "
<< num_trees + 1 << " trees.";
tree_config->add_nodes()->mutable_leaf();
boosted_trees::trees::DecisionTreeMetadata* const tree_metadata =
ensemble_resource->LastTreeMetadata();
tree_metadata->set_is_finalized(max_tree_depth <= 1);
tree_metadata->set_num_tree_weight_updates(1);
} else {
// The growable tree is by definition the last tree in the ensemble.
boosted_trees::trees::DecisionTreeMetadata* const tree_metadata =
ensemble_resource->LastTreeMetadata();
const auto new_num_layers = tree_metadata->num_layers_grown() + 1;
VLOG(1) << "Adding layer #" << new_num_layers - 1 << " to tree #"
<< num_trees - 1 << " of ensemble of " << num_trees << " trees.";
// Update growable tree metadata.
tree_metadata->set_num_layers_grown(new_num_layers);
tree_metadata->set_is_finalized(new_num_layers >= max_tree_depth);
}
UpdateTreeWeightsIfDropout(ensemble_resource, dropout_seed);
return ensemble_resource->LastTree();
}
// Helper method to merge leaf weights as the tree is being grown.
boosted_trees::trees::Leaf* MergeLeafWeights(
const boosted_trees::trees::Leaf& source,
boosted_trees::trees::Leaf* dest) {
// Resolve leaf merging method based on how the trees are being grown.
if (learner_config_.growing_mode() ==
boosted_trees::learner::LearnerConfig::WHOLE_TREE) {
// No merging occurs when building a whole tree at a time.
return dest;
}
if (dest->leaf_case() == boosted_trees::trees::Leaf::LEAF_NOT_SET) {
// No merging is required. Just copy the source weights.
*dest = source;
return dest;
}
// Handle leaf merging based on type.
switch (source.leaf_case()) {
case boosted_trees::trees::Leaf::kVector: {
// No-op if source is empty
const auto& src_vec = source.vector();
if (src_vec.value_size() == 0) {
break;
}
CHECK(source.leaf_case() == dest->leaf_case());
// Dense add leaf vectors.
auto* dst_vec = dest->mutable_vector();
CHECK(src_vec.value_size() == dst_vec->value_size());
for (size_t idx = 0; idx < source.vector().value_size(); ++idx) {
(*dst_vec->mutable_value()->Mutable(idx)) += src_vec.value(idx);
}
break;
}
case boosted_trees::trees::Leaf::kSparseVector: {
// No-op if source is empty
const auto& src_vec = source.sparse_vector();
CHECK(src_vec.value_size() == src_vec.index_size());
if (src_vec.value_size() == 0) {
break;
}
CHECK(source.leaf_case() == dest->leaf_case());
// Get mapping of dimension to value for destination.
std::unordered_map<int32, float> dst_map;
auto* dst_vec = dest->mutable_sparse_vector();
CHECK(dst_vec->value_size() == dst_vec->index_size());
dst_map.reserve(dst_vec->value_size());
for (size_t idx = 0; idx < dst_vec->value_size(); ++idx) {
dst_map[dst_vec->index(idx)] = dst_vec->value(idx);
}
// Sparse add source vector to destination vector.
for (size_t idx = 0; idx < src_vec.value_size(); ++idx) {
dst_map[src_vec.index(idx)] += src_vec.value(idx);
}
// Rebuild merged destination leaf.
dst_vec->clear_index();
dst_vec->clear_value();
for (const auto& entry : dst_map) {
dst_vec->add_index(entry.first);
dst_vec->add_value(entry.second);
}
break;
}
case boosted_trees::trees::Leaf::LEAF_NOT_SET: {
// No-op as there is nothing to merge.
break;
}
}
return dest;
}
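// The sparse-leaf merge rule above, restated on plain containers: indices
// present in either input survive, and values on shared indices are
// summed. MergeSparseSketch is an illustrative name local to this sketch
// and is not called by the op.
static std::unordered_map<int32, float> MergeSparseSketch(
    const std::unordered_map<int32, float>& dest,
    const std::vector<std::pair<int32, float>>& source) {
  std::unordered_map<int32, float> merged = dest;
  for (const auto& index_and_value : source) {
    merged[index_and_value.first] += index_and_value.second;
  }
  return merged;
}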
// Helper method to split a tree node and append its respective
// leaf children given the split candidate.
void SplitTreeNode(
const int32 node_id, SplitCandidate* split,
DecisionTreeConfig* tree_config,
const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
// No-op if we have no real node.
CHECK(node_id < tree_config->nodes_size())
<< "Invalid node " << node_id << " to split.";
// Ensure new split node is valid.
CHECK(split->split_info.split_node().node_case() != TreeNode::NODE_NOT_SET);
CHECK(tree_config->nodes(node_id).node_case() == TreeNode::kLeaf)
<< "Unexpected node type to split "
<< tree_config->nodes(node_id).node_case() << " for node_id " << node_id
<< ". Tree config: " << tree_config->DebugString();
// Add left leaf.
int32 left_id = tree_config->nodes_size();
(*tree_config->add_nodes()->mutable_leaf()) =
*MergeLeafWeights(tree_config->nodes(node_id).leaf(),
split->split_info.mutable_left_child());
// Add right leaf.
int32 right_id = tree_config->nodes_size();
(*tree_config->add_nodes()->mutable_leaf()) =
*MergeLeafWeights(tree_config->nodes(node_id).leaf(),
split->split_info.mutable_right_child());
// Link children and add them as new roots.
boosted_trees::trees::DecisionTree::LinkChildren(
{left_id, right_id}, split->split_info.mutable_split_node());
// Add split gain and, if needed, original leaf to node metadata.
TreeNodeMetadata* node_metadata =
split->split_info.mutable_split_node()->mutable_node_metadata();
node_metadata->set_gain(split->gain);
if (learner_config_.pruning_mode() ==
boosted_trees::learner::LearnerConfig::POST_PRUNE) {
(*node_metadata->mutable_original_leaf()) =
*tree_config->mutable_nodes(node_id)->mutable_leaf();
}
// Replace node in tree.
(*tree_config->mutable_nodes(node_id)) =
*split->split_info.mutable_split_node();
if (learner_config_.constraints().max_number_of_unique_feature_columns()) {
resource->MaybeAddUsedHandler(split->handler_id);
}
}
void SplitTreeLayer(
SplitCandidate* split, DecisionTreeConfig* tree_config,
const core::RefCountPtr<DecisionTreeEnsembleResource>& resource) {
int depth = 0;
while (depth < tree_config->nodes_size() &&
tree_config->nodes(depth).node_case() != TreeNode::kLeaf) {
depth++;
}
CHECK(tree_config->nodes_size() > 0)
<< "A tree must have at least one dummy leaf.";
// The number of new children.
int num_children = 1 << (depth + 1);
auto split_info = split->oblivious_split_info;
CHECK(num_children >= split_info.children_size())
<< "Too many new children, expected <= " << num_children << " and got "
<< split_info.children_size();
std::vector<trees::Leaf> new_leaves;
new_leaves.reserve(num_children);
int next_id = 0;
for (int idx = 0; idx < num_children / 2; idx++) {
trees::Leaf old_leaf =
*tree_config->mutable_nodes(depth + idx)->mutable_leaf();
// Check if a split was made for this leaf.
if (next_id < split_info.children_parent_id_size() &&
depth + idx == split_info.children_parent_id(next_id)) {
// Add left leaf.
new_leaves.push_back(*MergeLeafWeights(
old_leaf, split_info.mutable_children(2 * next_id)));
// Add right leaf.
new_leaves.push_back(*MergeLeafWeights(
old_leaf, split_info.mutable_children(2 * next_id + 1)));
next_id++;
} else {
// If there is no split for this leaf, just duplicate it.
new_leaves.push_back(old_leaf);
new_leaves.push_back(old_leaf);
}
}
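    // Every recorded split must correspond to exactly one old leaf.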
CHECK(next_id == split_info.children_parent_id_size());
TreeNodeMetadata* split_metadata =
split_info.mutable_split_node()->mutable_node_metadata();
split_metadata->set_gain(split->gain);
TreeNode new_split = *split_info.mutable_split_node();
// Move old children to metadata.
for (int idx = depth; idx < tree_config->nodes_size(); idx++) {
*new_split.mutable_node_metadata()->add_original_oblivious_leaves() =
*tree_config->mutable_nodes(idx)->mutable_leaf();
}
// Add the new split to the tree_config in place before the children start.
*tree_config->mutable_nodes(depth) = new_split;
    // Add the new children.
int nodes_size = tree_config->nodes_size();
for (int idx = 0; idx < num_children; idx++) {
if (idx + depth + 1 < nodes_size) {
// Update leaves that were already there.
*tree_config->mutable_nodes(idx + depth + 1)->mutable_leaf() =
new_leaves[idx];
} else {
// Add new leaves.
*tree_config->add_nodes()->mutable_leaf() = new_leaves[idx];
}
}
}
void PruneTree(boosted_trees::trees::DecisionTreeConfig* tree_config) {
// No-op if tree is empty.
if (tree_config->nodes_size() <= 0) {
return;
}
// Copy nodes to temp vector and clear original tree.
std::vector<TreeNode> tree_nodes;
tree_nodes.reserve(tree_config->nodes_size());
for (auto& node : (*tree_config->mutable_nodes())) {
tree_nodes.push_back(node);
node.Clear();
}
tree_config->clear_nodes();
// Prune the tree recursively starting from the root.
RecursivePruneTree(0, &tree_nodes);
// Rebuild compacted tree.
(*tree_config->add_nodes()) = tree_nodes[0];
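    // Maps the index of a node in the original tree to its index in the
    // rebuilt, compacted tree.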
std::unordered_map<size_t, size_t> nodes_map;
nodes_map[0] = 0;
for (size_t node_idx = 0; node_idx < tree_nodes.size(); ++node_idx) {
// Skip pruned nodes.
auto& original_node = tree_nodes[node_idx];
if (original_node.node_case() == TreeNode::NODE_NOT_SET) {
continue;
}
// Find node mapped in tree ensemble.
auto mapped_node_it = nodes_map.find(node_idx);
CHECK(mapped_node_it != nodes_map.end());
auto& mapped_node = (*tree_config->mutable_nodes(mapped_node_it->second));
      // Get node children.
auto children =
boosted_trees::trees::DecisionTree::GetChildren(original_node);
for (int32& child_idx : children) {
auto new_idx = tree_config->nodes_size();
(*tree_config->add_nodes()) = tree_nodes[child_idx];
nodes_map[child_idx] = new_idx;
child_idx = new_idx;
}
boosted_trees::trees::DecisionTree::LinkChildren(children, &mapped_node);
}
// Check if there are any nodes with gain left.
if (tree_config->nodes_size() == 1 &&
tree_config->nodes(0).node_metadata().gain() <= 0) {
// The whole tree should be pruned.
VLOG(2) << "No useful nodes left after post-pruning tree.";
tree_config->clear_nodes();
}
}
private:
boosted_trees::learner::LearnerConfig learner_config_;
int64 num_handlers_;
LearningRateDropoutDrivenConfig dropout_config_;
bool dropout_was_applied_;
bool center_bias_;
};
REGISTER_KERNEL_BUILDER(Name("GrowTreeEnsemble").Device(DEVICE_CPU),
GrowTreeEnsembleOp);
class TreeEnsembleStatsOp : public OpKernel {
public:
explicit TreeEnsembleStatsOp(OpKernelConstruction* const context)
: OpKernel(context) {}
void Compute(OpKernelContext* const context) override {
// Get decision tree ensemble.
core::RefCountPtr<DecisionTreeEnsembleResource> ensemble_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
&ensemble_resource));
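    // A shared (reader) lock suffices since this Op only reads the ensemble.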
tf_shared_lock l(*ensemble_resource->get_mutex());
// Get the stamp token.
const Tensor* stamp_token_t;
OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
int64 stamp_token = stamp_token_t->scalar<int64>()();
    // Only the Chief should run this Op, and it is guaranteed to be in
    // a consistent state, so the stamps must always match.
CHECK(ensemble_resource->is_stamp_valid(stamp_token));
const boosted_trees::trees::DecisionTreeEnsembleConfig& ensemble_config =
ensemble_resource->decision_tree_ensemble();
// Set tree stats.
Tensor* num_trees_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
"num_trees", TensorShape({}), &num_trees_t));
Tensor* active_tree_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("active_tree", TensorShape({}),
&active_tree_t));
Tensor* attempted_tree_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("attempted_trees", TensorShape({}),
&attempted_tree_t));
const int num_trees = ensemble_resource->num_trees();
active_tree_t->scalar<int64>()() = num_trees;
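    // Do not count the last tree if it is still being grown (not finalized).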
num_trees_t->scalar<int64>()() =
(num_trees <= 0 ||
ensemble_resource->LastTreeMetadata()->is_finalized())
? num_trees
: num_trees - 1;
attempted_tree_t->scalar<int64>()() =
ensemble_config.growing_metadata().num_trees_attempted();
// Set layer stats.
Tensor* num_layers_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
"num_layers", TensorShape({}), &num_layers_t));
Tensor* active_layer_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("active_layer", TensorShape({}),
&active_layer_t));
Tensor* attempted_layers_t = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output("attempted_layers", TensorShape({}),
&attempted_layers_t));
int64 num_layers = 0;
for (const auto& tree_metadata : ensemble_config.tree_metadata()) {
num_layers += tree_metadata.num_layers_grown();
}
num_layers_t->scalar<int64>()() = num_layers;
int tree_metadata_size = ensemble_config.tree_metadata_size();
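    // The active layer count comes from the most recently grown tree.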
active_layer_t->scalar<int64>()() =
tree_metadata_size > 0
? ensemble_config.tree_metadata(tree_metadata_size - 1)
.num_layers_grown()
: 0;
attempted_layers_t->scalar<int64>()() =
ensemble_config.growing_metadata().num_layers_attempted();
}
};
REGISTER_KERNEL_BUILDER(Name("TreeEnsembleStats").Device(DEVICE_CPU),
TreeEnsembleStatsOp);
} // namespace boosted_trees
} // namespace tensorflow

View File

@ -1,447 +0,0 @@
# Description:
# This directory contains common utilities used in boosted_trees.
load("//tensorflow:tensorflow.bzl", "py_test", "tf_cc_binary", "tf_cc_test")
package(
default_visibility = [
"//tensorflow/contrib/boosted_trees:__subpackages__",
"//tensorflow/contrib/boosted_trees:friends",
],
licenses = ["notice"], # Apache 2.0
)
exports_files(["LICENSE"])
# Utils
cc_library(
name = "utils",
srcs = [
"utils/batch_features.cc",
"utils/dropout_utils.cc",
"utils/examples_iterable.cc",
"utils/parallel_for.cc",
"utils/sparse_column_iterable.cc",
"utils/tensor_utils.cc",
],
hdrs = [
"utils/batch_features.h",
"utils/dropout_utils.h",
"utils/example.h",
"utils/examples_iterable.h",
"utils/macros.h",
"utils/optional_value.h",
"utils/parallel_for.h",
"utils/random.h",
"utils/sparse_column_iterable.h",
"utils/tensor_utils.h",
],
deps = [
"//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
],
)
tf_cc_test(
name = "sparse_column_iterable_test",
size = "small",
srcs = ["utils/sparse_column_iterable_test.cc"],
deps = [
":utils",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "examples_iterable_test",
size = "small",
srcs = ["utils/examples_iterable_test.cc"],
deps = [
":utils",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/algorithm:container",
],
)
tf_cc_test(
name = "example_test",
size = "small",
srcs = ["utils/example_test.cc"],
deps = [
":utils",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "batch_features_test",
size = "small",
srcs = ["utils/batch_features_test.cc"],
deps = [
":utils",
"//tensorflow/core:lib",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "dropout_utils_test",
size = "small",
srcs = ["utils/dropout_utils_test.cc"],
deps = [
":utils",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
"//tensorflow/core:lib",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# Models
cc_library(
name = "models",
srcs = ["models/multiple_additive_trees.cc"],
hdrs = ["models/multiple_additive_trees.h"],
deps = [
":trees",
":utils",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
"//tensorflow/core:framework_headers_lib",
],
)
tf_cc_test(
name = "multiple_additive_trees_test",
size = "small",
srcs = ["models/multiple_additive_trees_test.cc"],
deps = [
":batch_features_testutil",
":models",
":random_tree_gen",
"//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# Testutil
cc_library(
name = "batch_features_testutil",
testonly = 1,
srcs = ["testutil/batch_features_testutil.cc"],
hdrs = ["testutil/batch_features_testutil.h"],
deps = [
":utils",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:test",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "random_tree_gen",
srcs = ["testutil/random_tree_gen.cc"],
hdrs = ["testutil/random_tree_gen.h"],
deps = [
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
"//tensorflow/core:lib",
],
)
tf_cc_binary(
name = "random_tree_gen_main",
srcs = ["testutil/random_tree_gen_main.cc"],
deps = [
":random_tree_gen",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
# Quantiles
cc_library(
name = "weighted_quantiles",
srcs = [],
hdrs = [
"quantiles/weighted_quantiles_buffer.h",
"quantiles/weighted_quantiles_stream.h",
"quantiles/weighted_quantiles_summary.h",
],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework_headers_lib",
],
)
tf_cc_test(
name = "weighted_quantiles_buffer_test",
size = "small",
srcs = ["quantiles/weighted_quantiles_buffer_test.cc"],
deps = [
":weighted_quantiles",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "weighted_quantiles_summary_test",
size = "small",
srcs = ["quantiles/weighted_quantiles_summary_test.cc"],
deps = [
":weighted_quantiles",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "weighted_quantiles_stream_test",
size = "small",
srcs = ["quantiles/weighted_quantiles_stream_test.cc"],
deps = [
":weighted_quantiles",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# Trees
cc_library(
name = "trees",
srcs = ["trees/decision_tree.cc"],
hdrs = ["trees/decision_tree.h"],
deps = [
"//tensorflow/contrib/boosted_trees/lib:utils",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
"//tensorflow/core:framework_headers_lib",
],
)
tf_cc_test(
name = "trees_test",
size = "small",
srcs = ["trees/decision_tree_test.cc"],
deps = [
":trees",
"//tensorflow/contrib/boosted_trees/lib:utils",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# Learner/batch
py_library(
name = "base_split_handler",
srcs = ["learner/batch/base_split_handler.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/boosted_trees:batch_ops_utils_py",
"//tensorflow/python:control_flow_ops",
],
)
py_library(
name = "categorical_split_handler",
srcs = ["learner/batch/categorical_split_handler.py"],
srcs_version = "PY2AND3",
deps = [
":base_split_handler",
"//tensorflow/contrib/boosted_trees:split_handler_ops_py",
"//tensorflow/contrib/boosted_trees:stats_accumulator_ops_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:math_ops",
],
)
py_test(
name = "categorical_split_handler_test",
srcs = ["learner/batch/categorical_split_handler_test.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":categorical_split_handler",
"//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:resources",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
],
)
py_library(
name = "ordinal_split_handler",
srcs = ["learner/batch/ordinal_split_handler.py"],
srcs_version = "PY2AND3",
deps = [
":base_split_handler",
"//tensorflow/contrib/boosted_trees:quantile_ops_py",
"//tensorflow/contrib/boosted_trees:split_handler_ops_py",
"//tensorflow/contrib/boosted_trees:stats_accumulator_ops_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_tensor",
],
)
py_test(
name = "ordinal_split_handler_test",
srcs = ["learner/batch/ordinal_split_handler_test.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [
":ordinal_split_handler",
"//tensorflow/contrib/boosted_trees/proto:learner_proto_py",
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:resources",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
],
)
# Learner/Common
cc_library(
name = "class-partition-key",
hdrs = ["learner/common/accumulators/class-partition-key.h"],
deps = [
"//tensorflow/core:framework_headers_lib",
],
)
cc_library(
name = "feature-stats-accumulator",
hdrs = ["learner/common/accumulators/feature-stats-accumulator.h"],
deps = [
":class-partition-key",
],
)
tf_cc_test(
name = "feature-stats-accumulator_test",
size = "small",
srcs = ["learner/common/accumulators/feature-stats-accumulator_test.cc"],
deps = [
":feature-stats-accumulator",
"//tensorflow/core:lib",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "example_partitioner",
srcs = ["learner/common/partitioners/example_partitioner.cc"],
hdrs = ["learner/common/partitioners/example_partitioner.h"],
deps = [
"//tensorflow/contrib/boosted_trees/lib:trees",
"//tensorflow/contrib/boosted_trees/lib:utils",
"//tensorflow/core:framework_headers_lib",
],
)
tf_cc_test(
name = "example_partitioner_test",
size = "small",
srcs = ["learner/common/partitioners/example_partitioner_test.cc"],
deps = [
":example_partitioner",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# Learner/stochastic
cc_library(
name = "gradient-stats",
hdrs = ["learner/common/stats/gradient-stats.h"],
deps = [
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
],
)
cc_library(
name = "node-stats",
hdrs = ["learner/common/stats/node-stats.h"],
deps = [
":gradient-stats",
"//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
],
)
cc_library(
name = "split-stats",
hdrs = ["learner/common/stats/split-stats.h"],
deps = [
":node-stats",
],
)
cc_library(
name = "feature-split-candidate",
hdrs = ["learner/common/stats/feature-split-candidate.h"],
deps = [
":split-stats",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
],
)
tf_cc_test(
name = "node-stats_test",
size = "small",
srcs = ["learner/common/stats/node-stats_test.cc"],
deps = [
":node-stats",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

Some files were not shown because too many files have changed in this diff