Add correct dependencies to sdca ops to fix build breakage.

Change: 115408162
A. Unique TensorFlower 2016-02-23 18:51:19 -08:00 committed by TensorFlower Gardener
parent 185cff7f41
commit 94a992cfc3
18 changed files with 1364 additions and 21 deletions


@ -13,6 +13,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
"//tensorflow/contrib/util:util_py",
],
)


@ -21,4 +21,5 @@ from __future__ import print_function
# Add projects here, they will show up under tf.contrib.
from tensorflow.contrib import layers
from tensorflow.contrib import linear_optimizer
from tensorflow.contrib import util


@ -0,0 +1,85 @@
# Description:
# Contains ops to train linear models on top of TensorFlow.
# APIs here are meant to evolve over time.
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
cc_library(
name = "sdca_kernel",
srcs = ["kernels/sdca_ops.cc"],
hdrs = ["kernels/logistic-loss.h"],
visibility = ["//visibility:private"],
deps = [
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all",
],
alwayslink = 1,
)
cc_library(
name = "sdca_ops",
srcs = ["ops/sdca_ops.cc"],
deps = [
"//tensorflow/core:framework",
],
alwayslink = 1,
)
tf_gen_op_libs(
op_lib_names = ["sdca_ops"],
)
tf_gen_op_wrapper_py(
name = "gen_sdca_ops_py",
out = "gen_sdca_ops.py",
deps = [":sdca_ops"],
)
py_library(
name = "sdca_ops_py",
srcs = [
"__init__.py",
"python/ops/sdca_ops.py",
],
srcs_version = "PY2AND3",
deps = [
":gen_sdca_ops_py",
],
)
py_test(
name = "sdca_ops_test",
srcs = ["python/kernel_tests/sdca_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
":sdca_kernel",
":sdca_ops",
":sdca_ops_py",
"//third_party/py/tensorflow",
"//tensorflow/core:all_kernels",
"//tensorflow/core:protos_all_py_pb2",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)


@ -0,0 +1,21 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for training linear models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import *


@ -0,0 +1,103 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_LOGISTIC_LOSS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_LOGISTIC_LOSS_H_
#include <algorithm>
#include <cmath>
namespace tensorflow {
struct logistic_loss {
// Use an approximate step that is guaranteed to decrease the dual loss.
// Derivation of this is available on page 14, Eq. 16 of
// http://arxiv.org/pdf/1211.2717v1.pdf
inline static double ComputeUpdatedDual(const double label,
const double example_weight,
const double current_dual,
const double wx,
const double weighted_example_norm,
const double primal_loss,
const double dual_loss) {
const double ywx = label * wx;
// To avoid overflow, we compute the derivative of the logistic loss with
// respect to the log-odds as follows.
double inverse_exp_term = 0;
if (ywx > 0) {
const double exp_minus_ywx = exp(-ywx);
inverse_exp_term = exp_minus_ywx / (1 + exp_minus_ywx);
} else {
inverse_exp_term = 1 / (1 + exp(ywx));
}
// For the conjugate f*(a) = sup_x (a*x - f(x)), the supremum is attained when
// a = f'(x); here a is the approximate dual.
const double approximate_dual = inverse_exp_term * label;
const double delta_dual = approximate_dual - current_dual;
// Upper bound on the smoothness constant of log loss. This is 0.25 i.e.
// when log-odds is zero.
const double gamma =
(wx == 0) ? 0.25 : (1 - 2 * inverse_exp_term) / (2 * ywx);
const double wx_dual = wx * current_dual * example_weight;
const double delta_dual_squared = delta_dual * delta_dual;
const double smooth_delta_dual_squared = delta_dual_squared * gamma * 0.5;
double multiplier =
(primal_loss + dual_loss + wx_dual + smooth_delta_dual_squared) /
std::max(1.0,
delta_dual_squared *
(gamma +
weighted_example_norm * example_weight * example_weight));
// Multiplier must be in the range [0, 1].
multiplier = std::max(std::min(1.0, multiplier), 0.0);
return current_dual + delta_dual * multiplier;
}
// Dual (convex conjugate) of the logistic loss function.
// https://en.wikipedia.org/wiki/Convex_conjugate
inline static double ComputeDualLoss(const double current_dual,
const double example_label,
const double example_weight) {
// Dual of the logistic loss function is
// ay * log(ay) + (1-ay) * log (1-ay), where a is the dual variable.
const double ay = current_dual * example_label;
const double log_ay = (ay > 0) ? log(ay) : 0;
const double one_minus_ay = 1 - ay;
const double log_one_minus_ay = (one_minus_ay > 0) ? log(one_minus_ay) : 0;
return ((ay * log_ay) + (one_minus_ay * log_one_minus_ay)) * example_weight;
}
// Logistic loss for binary classification.
// https://en.wikipedia.org/wiki/Loss_functions_for_classification
inline static double ComputePrimalLoss(const double wx,
const double example_label,
const double example_weight) {
// Logistic loss:
// log(1 + e^(-ywx))
// log(e^0 + e^(-ywx))
// a + log(e^(0-a) + e^(-ywx - a)), where a is max(0, -ywx)
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
const double y_wx = example_label * wx;
if (y_wx > 0) {
// 0 + log(e^(0) + e^(-ywx - 0))
// log(1 + e^(-ywx))
return log(1 + exp(-y_wx)) * example_weight;
}
// -ywx + log(e^(ywx) + e^(-ywx + ywx))
// log(e^(ywx) + e^(0)) - ywx
// log(1 + e^(ywx)) - ywx
return (log(1 + exp(y_wx)) - y_wx) * example_weight;
}
};
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_LOGISTIC_LOSS_H_
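
For reference, a minimal NumPy sketch of the primal logistic loss and its convex conjugate as defined in the header above; this is an illustration of the math only (function names are mine), not code from the commit:

```python
import numpy as np

def primal_logistic_loss(wx, label, weight):
    # log(1 + e^(-y*wx)) * weight, computed with the log-sum-exp shift so
    # the exponent is never positive.
    ywx = label * wx
    if ywx > 0:
        return np.log1p(np.exp(-ywx)) * weight
    return (np.log1p(np.exp(ywx)) - ywx) * weight

def dual_logistic_loss(alpha, label, weight):
    # Convex conjugate: (ay*log(ay) + (1-ay)*log(1-ay)) * weight,
    # with the convention 0*log(0) = 0.
    ay = alpha * label
    xlogx = lambda p: p * np.log(p) if p > 0 else 0.0
    return (xlogx(ay) + xlogx(1.0 - ay)) * weight
```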


@ -0,0 +1,458 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/sdca_ops.cc.
#define EIGEN_USE_THREADS
#include <atomic>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
namespace {
// A feature group of a single example is represented by this struct.
struct PerExampleSparseIndicesWeights {
// N X 2 matrix with (example_id, feature_indices).
tensorflow::TTypes<const int64>::UnalignedMatrix indices;
// N X 1 vector with feature weights.
tensorflow::TTypes<float>::UnalignedVec values;
// sum squared norm of the features.
double norm;
};
struct Regularizations {
float symmetric_l1 = 0;
float symmetric_l2 = 0;
};
struct RegularizationLoss {
double l1_loss = 0;
double l2_loss = 0;
};
struct PerExampleData {
double wx = 0;
double norm = 0;
};
// Tensor vector of floats which holds the weights.
using Weights = TTypes<float>::Vec;
// Weights associated with a feature group; the size of WeightsByIndex is
// the number of feature groups.
using WeightsByIndex = std::vector<Weights>;
// SparseExamples holds the per-example sparse data of a single feature group.
using SparseExamples =
vector<std::unique_ptr<const PerExampleSparseIndicesWeights>>;
// SparseExamples associated with each sparse feature group.
using SparseExamplesByIndex = vector<SparseExamples>;
using DenseFeaturesByIndex = vector<tensorflow::TTypes<const float>::Vec>;
// Compute the shrinkage factor for proximal sdca.
inline double ShrinkageFactor(const Regularizations& regularizations) {
return regularizations.symmetric_l1 / regularizations.symmetric_l2;
}
// Proximal SDCA shrinking for L1 regularization.
inline double Shrink(const double weight, const double shrink_by) {
const double shrink_weight = std::max(std::abs(weight) - shrink_by, 0.0);
if (shrink_weight > 0.0) {
return std::copysign(shrink_weight, weight);
}
return 0.0;
}
// Compute L1 and L2 regularization loss.
inline RegularizationLoss ComputeRegularizationLoss(
const WeightsByIndex& sparse_weights_by_index,
const WeightsByIndex& dense_weights_by_index,
const Regularizations& regularizations) {
RegularizationLoss result;
const double shrink_by = ShrinkageFactor(regularizations);
auto accumulate_regularization_loss = [&](const double w) {
const double sw = std::abs(Shrink(w, shrink_by));
result.l1_loss += sw;
result.l2_loss += sw * sw;
};
for (auto& sparse_weights : sparse_weights_by_index) {
for (size_t i = 0; i < sparse_weights.size(); ++i) {
accumulate_regularization_loss(sparse_weights(i));
}
}
for (auto& dense_weights : dense_weights_by_index) {
accumulate_regularization_loss(dense_weights(0));
}
result.l1_loss *= regularizations.symmetric_l1;
result.l2_loss *= regularizations.symmetric_l2;
return result;
}
// Compute PerExampleData which contains the logits, and weighted example norm
// for a given example_id. Norm is weighted by 1/(lambda*N).
inline PerExampleData ComputeWxAndWeightedExampleNorm(
const int64 example_id, const WeightsByIndex& sparse_weights_by_index,
const SparseExamplesByIndex& sparse_examples_by_index,
const WeightsByIndex& dense_weights_by_index,
const DenseFeaturesByIndex& dense_features,
const Regularizations& regularizations) {
PerExampleData result;
const double shrink_by = ShrinkageFactor(regularizations);
for (size_t i = 0; i < sparse_examples_by_index.size(); ++i) {
const SparseExamples& sparse_indices_values = sparse_examples_by_index[i];
const Weights sparse_weights = sparse_weights_by_index[i];
if (sparse_indices_values[example_id]) {
const auto indices = sparse_indices_values[example_id]->indices;
const auto values = sparse_indices_values[example_id]->values;
for (size_t dim = 0; dim < indices.dimension(0); ++dim) {
result.wx +=
Shrink(sparse_weights(indices(dim, 1)), shrink_by) * values(dim);
}
result.norm += sparse_indices_values[example_id]->norm;
}
}
for (size_t i = 0; i < dense_features.size(); ++i) {
const auto dense_values = dense_features[i];
const Weights dense_weights = dense_weights_by_index[i];
result.wx += Shrink(dense_weights(0), shrink_by) * dense_values(example_id);
result.norm += dense_values(example_id) * dense_values(example_id);
}
result.norm /= regularizations.symmetric_l2;
return result;
}
// Apply L1 regularization on the weights.
void ShrinkWeights(const Regularizations& regularizations,
WeightsByIndex* const sparse_weights_by_index,
WeightsByIndex* const dense_weights_by_index) {
const double shrink_by = ShrinkageFactor(regularizations);
for (auto& sparse_weights : *sparse_weights_by_index) {
for (size_t i = 0; i < sparse_weights.size(); ++i) {
sparse_weights(i) = Shrink(sparse_weights(i), shrink_by);
}
}
for (auto& dense_weights : *dense_weights_by_index) {
dense_weights(0) = Shrink(dense_weights(0), shrink_by);
}
}
void UpdateWeights(const int64 example_id,
const SparseExamplesByIndex& sparse_examples_by_index,
const DenseFeaturesByIndex& dense_features,
const double bounded_dual_delta,
const double l2_regularization,
WeightsByIndex* const sparse_weights_by_index,
WeightsByIndex* const dense_weights_by_index) {
for (size_t i = 0; i < sparse_examples_by_index.size(); ++i) {
const SparseExamples& sparse_indices_values = sparse_examples_by_index[i];
Weights sparse_weights = (*sparse_weights_by_index)[i];
if (sparse_indices_values[example_id]) {
const auto indices = sparse_indices_values[example_id]->indices;
const auto values = sparse_indices_values[example_id]->values;
for (size_t dim = 0; dim < indices.dimension(0); ++dim) {
// TODO(rohananil): Atomic updates provide better convergence guarantees.
// However, casting float to atomic<float> is UB. We may consider a
// sharded set of locks, or bringing the primal-dual relationship to a
// consistent state after several epochs.
sparse_weights(indices(dim, 1)) +=
bounded_dual_delta * values(dim) / l2_regularization;
}
}
}
for (size_t i = 0; i < dense_features.size(); ++i) {
const auto dense_values = dense_features[i];
Weights dense_weights = (*dense_weights_by_index)[i];
// TODO(rohananil): Atomic updates provide better convergence guarantees.
// However, casting float to atomic<float> is UB. We may consider a
// sharded set of locks, or bringing the primal-dual relationship to a
// consistent state after several epochs.
dense_weights(0) +=
bounded_dual_delta * dense_values(example_id) / l2_regularization;
}
}
// Atomically add a double to a std::atomic<double>.
inline void AtomicAdd(const double value, std::atomic<double>* const dst) {
// We use a strong version of compare-exchange, as weak version can spuriously
// fail.
for (double c = dst->load(); !dst->compare_exchange_strong(c, c + value);) {
}
}
} // namespace
class SdcaSolver : public OpKernel {
public:
explicit SdcaSolver(OpKernelConstruction* context) : OpKernel(context) {
string loss_type;
OP_REQUIRES_OK(context, context->GetAttr("LossType", &loss_type));
if (loss_type == "logistic_loss") {
compute_dual_loss_ = logistic_loss::ComputeDualLoss;
compute_primal_loss_ = logistic_loss::ComputePrimalLoss;
compute_dual_update_ = logistic_loss::ComputeUpdatedDual;
}
OP_REQUIRES_OK(
context, context->GetAttr("NumSparseFeatures", &num_sparse_features_));
OP_REQUIRES_OK(context,
context->GetAttr("NumDenseFeatures", &num_dense_features_));
OP_REQUIRES(
context, num_sparse_features_ + num_dense_features_ > 0,
errors::InvalidArgument("Requires at least one feature to train."));
OP_REQUIRES_OK(context,
context->GetAttr("L1", &regularizations_.symmetric_l1));
OP_REQUIRES_OK(context,
context->GetAttr("L2", &regularizations_.symmetric_l2));
OP_REQUIRES_OK(context, context->GetAttr("DualityGapThreshold",
&duality_gap_threshold_));
}
void Compute(OpKernelContext* context) override {
const Tensor* example_weights_t;
OP_REQUIRES_OK(context,
context->input("example_weights", &example_weights_t));
const auto example_weights = example_weights_t->vec<float>();
Tensor primal_loss_t;
OP_REQUIRES_OK(context,
context->mutable_input("primal_loss", &primal_loss_t,
/*lock_held=*/true));
auto primal_loss = primal_loss_t.scalar<double>();
const int64 num_examples = example_weights.size();
Eigen::Tensor<float, 0, Eigen::RowMajor> example_weights_sum;
example_weights_sum.device(context->eigen_cpu_device()) =
example_weights.sum();
const float weighted_examples = example_weights_sum();
OP_REQUIRES(context, weighted_examples > 0.0,
errors::InvalidArgument("No weighted examples in ",
num_examples, " training examples"));
// We scale the regularization strengths by the number of weighted examples.
regularizations_.symmetric_l1 =
regularizations_.symmetric_l1 * weighted_examples;
regularizations_.symmetric_l2 =
std::max(regularizations_.symmetric_l2 * weighted_examples, 1.0f);
OpInputList sparse_features_indices_inputs;
OP_REQUIRES_OK(context,
context->input_list("sparse_features_indices",
&sparse_features_indices_inputs));
OpInputList sparse_features_values;
OP_REQUIRES_OK(context, context->input_list("sparse_features_values",
&sparse_features_values));
SparseExamplesByIndex sparse_examples_by_index(num_sparse_features_);
// Goes through the entire training set once to create per-example
// structures, which we use downstream when randomizing the order of the
// examples.
auto do_parse = [&](const int64 begin, const int64 end) {
// We set the order to [0, 1], which specifies row-major ordering. This
// means the first column holds example ids in lexicographically
// increasing order.
static const int64 kIndicesDims = 2;
gtl::InlinedVector<int64, 8> order(kIndicesDims);
std::iota(order.begin(), order.end(), 0);
for (size_t i = begin; i < end; ++i) {
OP_REQUIRES(context, sparse_features_indices_inputs[i].shape().dims() ==
kIndicesDims,
errors::InvalidArgument(
"Indices should have exactly 2 dimensions"));
tensorflow::sparse::SparseTensor st(
sparse_features_indices_inputs[i], sparse_features_values[i],
sparse_features_indices_inputs[i].shape(), order);
sparse_examples_by_index[i] = SparseExamples(num_examples);
for (const auto& example_group : st.group({0})) {
const int64 example_id = example_group.indices()(0, 0);
const Eigen::Tensor<float, 0, Eigen::RowMajor> norm =
example_group.values<float>().square().sum();
sparse_examples_by_index[i][example_id].reset(
new PerExampleSparseIndicesWeights{example_group.indices(),
example_group.values<float>(),
norm()});
}
}
};
{
const DeviceBase::CpuWorkerThreads* const worker_threads =
context->device()->tensorflow_cpu_worker_threads();
// For each column, the cost of parsing is O(num_examples). We use
// num_examples here, as empirically Shard() creates the right number of
// threads based on the problem size.
// TODO(rohananil): Tune this as a function of dataset size.
const int64 kCostPerUnit = num_examples;
Shard(worker_threads->num_threads, worker_threads->workers,
num_sparse_features_, kCostPerUnit, do_parse);
}
OpInputList dense_features_inputs;
OP_REQUIRES_OK(
context, context->input_list("dense_features", &dense_features_inputs));
DenseFeaturesByIndex dense_features;
for (auto& dense_features_input : dense_features_inputs) {
dense_features.emplace_back(dense_features_input.vec<float>());
}
const Tensor* example_labels_t;
OP_REQUIRES_OK(context,
context->input("example_labels", &example_labels_t));
const auto example_labels = example_labels_t->vec<float>();
Tensor dual_variables_t;
OP_REQUIRES_OK(context,
context->mutable_input("dual_variables", &dual_variables_t,
/*lock_held=*/true));
auto dual_variables = dual_variables_t.vec<float>();
OpMutableInputList sparse_weights_by_index_inputs;
OP_REQUIRES_OK(
context, context->mutable_input_list("sparse_weights_by_index",
&sparse_weights_by_index_inputs));
WeightsByIndex sparse_weights_by_index;
for (size_t i = 0; i < sparse_weights_by_index_inputs.size(); ++i) {
sparse_weights_by_index.emplace_back(
sparse_weights_by_index_inputs.at(i, /*lock_held=*/true)
.flat<float>());
}
OpMutableInputList dense_weights_by_index_inputs;
OP_REQUIRES_OK(context,
context->mutable_input_list("dense_weights_by_index",
&dense_weights_by_index_inputs));
WeightsByIndex dense_weights_by_index;
for (size_t i = 0; i < dense_weights_by_index_inputs.size(); ++i) {
dense_weights_by_index.emplace_back(
dense_weights_by_index_inputs.at(i, /*lock_held=*/true)
.flat<float>());
}
vector<int64> example_ids(num_examples);
std::iota(example_ids.begin(), example_ids.end(), 0);
std::random_device random_device;
std::mt19937 random_generator(random_device());
std::atomic<double> total_approx_duality_gap(
std::numeric_limits<double>::max());
std::atomic<double> total_primal_loss(0);
// Break when the duality gap |P(w) - D(alpha)| is less than
// duality_gap_threshold_.
while ((total_approx_duality_gap / weighted_examples) >
duality_gap_threshold_) {
// Reset and add everything.
total_approx_duality_gap = 0;
total_primal_loss = 0;
std::shuffle(example_ids.begin(), example_ids.end(), random_generator);
auto do_update = [&](const int64 begin, const int64 end) {
double approx_duality_gap = 0;
for (int64 offset = begin; offset < end; ++offset) {
// Get example id, label, and weight.
const int64 example_id = example_ids[offset];
OP_REQUIRES(context, !(example_labels(example_id) > 0 &&
example_labels(example_id) < 1),
errors::InvalidArgument(
"Fractional labels not supported right now. "
"Found example with label: ",
example_labels(example_id)));
const float example_label = example_labels(example_id) == 0 ? -1 : 1;
const double current_dual = dual_variables(example_id);
const double example_weight = example_weights(example_id);
// Compute wx, example norm weighted by regularization, dual loss,
// primal loss.
const PerExampleData per_example_data =
ComputeWxAndWeightedExampleNorm(
example_id, sparse_weights_by_index, sparse_examples_by_index,
dense_weights_by_index, dense_features, regularizations_);
const double dual_loss = compute_dual_loss_(
current_dual, example_label, example_weight);
const double primal_loss = compute_primal_loss_(
per_example_data.wx, example_label, example_weight);
approx_duality_gap += dual_loss + primal_loss;
// Update dual variable.
dual_variables(example_id) = compute_dual_update_(
example_label, example_weight, current_dual, per_example_data.wx,
per_example_data.norm, primal_loss, dual_loss);
// Compute new weights.
const double bounded_dual_delta =
(dual_variables(example_id) - current_dual) * example_weight;
UpdateWeights(example_id, sparse_examples_by_index, dense_features,
bounded_dual_delta, regularizations_.symmetric_l2,
&sparse_weights_by_index, &dense_weights_by_index);
AtomicAdd(approx_duality_gap, &total_approx_duality_gap);
AtomicAdd(primal_loss, &total_primal_loss);
}
// TODO(rohananil): We may in the future want to make the primal-dual
// relationship consistent as our current updates are not transactional.
};
const DeviceBase::CpuWorkerThreads* const worker_threads =
context->device()->tensorflow_cpu_worker_threads();
const int64 kCostPerUnit =
100000 * (num_sparse_features_ + num_dense_features_);
Shard(worker_threads->num_threads, worker_threads->workers, num_examples,
kCostPerUnit, do_update);
const RegularizationLoss regularization_loss = ComputeRegularizationLoss(
sparse_weights_by_index, dense_weights_by_index, regularizations_);
total_approx_duality_gap.store(total_approx_duality_gap.load() +
regularization_loss.l1_loss +
regularization_loss.l2_loss);
primal_loss() = (total_primal_loss.load() + regularization_loss.l1_loss +
regularization_loss.l2_loss) /
weighted_examples;
}
ShrinkWeights(regularizations_, &sparse_weights_by_index,
&dense_weights_by_index);
}
private:
std::function<decltype(logistic_loss::ComputeDualLoss)> compute_dual_loss_;
std::function<decltype(logistic_loss::ComputePrimalLoss)>
compute_primal_loss_;
std::function<decltype(logistic_loss::ComputeUpdatedDual)>
compute_dual_update_;
int64 num_sparse_features_;
int64 num_dense_features_;
Regularizations regularizations_;
float duality_gap_threshold_;
};
REGISTER_KERNEL_BUILDER(Name("SdcaSolver").Device(DEVICE_CPU), SdcaSolver);
} // namespace tensorflow
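
The Shrink/ShrinkageFactor helpers above implement the soft-thresholding step that proximal SDCA uses for L1 regularization, with shrink_by = symmetric_l1 / symmetric_l2. A small standalone Python sketch of the same operator (illustrative only, names are mine):

```python
def shrink(weight, shrink_by):
    # Soft-threshold the weight toward zero by shrink_by; this is the
    # proximal step for the L1 penalty.
    magnitude = max(abs(weight) - shrink_by, 0.0)
    return magnitude if weight >= 0 else -magnitude

print(shrink(0.75, 0.25))   # 0.5
print(shrink(-0.75, 0.25))  # -0.5
print(shrink(0.1, 0.25))    # 0.0 -- small weights are zeroed out
```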


@ -0,0 +1,74 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
// --------------------------------------------------------------------------
REGISTER_OP("SdcaSolver")
.Attr("LossType: {'logistic_loss'}")
.Attr("NumSparseFeatures: int >= 0")
.Attr("NumDenseFeatures: int >= 0")
.Attr("L1: float >= 0")
.Attr("L2: float >= 0")
.Attr("DualityGapThreshold: float = 0.01")
.Input("sparse_features_indices: NumSparseFeatures * int64")
.Input("sparse_features_values: NumSparseFeatures * float")
.Input("dense_features: NumDenseFeatures * float")
.Input("example_weights: float")
.Input("example_labels: float")
.Input("dual_variables: Ref(float)")
.Input(
"sparse_weights_by_index: Ref(NumSparseFeatures * "
"float)")
.Input(
"dense_weights_by_index: Ref(NumDenseFeatures * "
"float)")
.Input("primal_loss: Ref(double)")
.Doc(R"doc(
Stochastic Dual Coordinate Ascent (SDCA) optimizer for linear models with
L1 + L2 regularization. As the global optimization objective is strongly
convex, the optimizer optimizes the dual objective at each step. Updates are
applied one example at a time; examples are sampled uniformly, and the
optimizer is learning-rate free and enjoys a linear convergence rate.
Proximal Stochastic Dual Coordinate Ascent, Shalev-Shwartz, Shai; Zhang, Tong.
2012arXiv1211.2717S: http://arxiv.org/pdf/1211.2717v1.pdf
LossType: Type of the primal loss. Only logistic_loss is supported.
NumSparseFeatures: Number of sparse features to train on.
NumDenseFeatures: Number of dense features to train on.
L1: Per example symmetric l1 regularization strength.
L2: Per example symmetric l2 regularization strength.
DualityGapThreshold: Gap threshold at which we should stop training.
sparse_features_indices: a list of matrices with two columns that contain
example_indices, and feature_indices.
sparse_features_values: a list of vectors which contains feature value
associated with each feature index.
dense_features: a list of vectors which contains the dense feature values.
example_weights: a vector which contains the example weight associated with
each example.
example_labels: a vector which contains the example label/target associated
with each example.
dual_variables: a vector which contains the dual variable associated with each
example.
sparse_weights_by_index: a list of vectors where each value is the weight
associated with a feature index.
dense_weights_by_index: a list of vectors where the value is the weight
associated with the dense feature.
)doc");
} // namespace tensorflow
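
DualityGapThreshold above is the termination certificate: the kernel keeps sweeping over examples until the approximate duality gap per weighted example drops below this threshold. A rough Python sketch of that stopping check, mirroring the loop in kernels/sdca_ops.cc (names are illustrative, not part of the op's API):

```python
def should_stop(total_primal_loss, total_dual_loss, regularization_loss,
                weighted_examples, duality_gap_threshold=0.01):
    # Approximate duality gap accumulated over one pass: the sum of
    # per-example primal and dual losses plus the L1/L2 regularization
    # losses, normalized by the total example weight.
    approx_gap = total_primal_loss + total_dual_loss + regularization_loss
    return approx_gap / weighted_examples <= duality_gap_threshold
```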


@ -0,0 +1,269 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SdcaModel."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.framework.test_util import TensorFlowTestCase
from tensorflow.python.platform import googletest
# pylint: disable=invalid-name
SdcaModel = tf.contrib.linear_optimizer.SdcaModel
def make_example_proto(feature_dict, target):
e = tf.train.Example()
features = e.features
features.feature['target'].float_list.value.append(target)
for key, values in feature_dict.iteritems():
features.feature[key + '_indices'].int64_list.value.extend(values)
features.feature[key + '_values'].float_list.value.extend([1.0] *
len(values))
return e
def make_example_dict(example_protos, example_weights):
def parse_examples(example_protos):
features = {
'target': tf.FixedLenFeature(shape=[1],
dtype=tf.float32,
default_value=0),
'age_indices': tf.VarLenFeature(dtype=tf.int64),
'age_values': tf.VarLenFeature(dtype=tf.float32),
'gender_indices': tf.VarLenFeature(dtype=tf.int64),
'gender_values': tf.VarLenFeature(dtype=tf.float32)
}
return tf.parse_example(
[e.SerializeToString() for e in example_protos], features)
# TODO(rohananil): This converts two sparse tensors into one sparse feature
# tensor. Use the tf core op once it's merged in.
def sf_from_st(ids, weights):
example_indices, _ = tf.split(1, 2, ids.indices)
feature_indices = tf.expand_dims(ids.values, 1)
indices = tf.concat(1, [example_indices, feature_indices])
return tf.SparseTensor(indices, weights.values, ids.shape)
parsed = parse_examples(example_protos)
return dict(sparse_features=[
sf_from_st(parsed['age_indices'], parsed['age_values']), sf_from_st(
parsed[
'gender_indices'], parsed['gender_values'])
],
dense_features=[],
example_weights=example_weights,
example_labels=tf.reshape(parsed['target'], [-1]))
def make_variable_dict(examples_dict, max_age, max_gender):
# TODO(dbaylor): Figure out how to derive max_age & max_gender from
# examples_dict.
age_weights = tf.Variable(tf.zeros([max_age + 1], dtype=tf.float32))
gender_weights = tf.Variable(tf.zeros([max_gender + 1], dtype=tf.float32))
dual_variables = tf.Variable(tf.zeros_like(examples_dict['example_labels'],
dtype=tf.float32))
training_log_loss = tf.Variable(tf.zeros([], dtype=tf.float64))
return dict(sparse_features_weights=[age_weights, gender_weights],
dense_features_weights=[],
dual=dual_variables,
training_log_loss=training_log_loss)
class SdcaOptimizerTest(TensorFlowTestCase):
def testSimple(self):
# Setup test data
example_protos = [
make_example_proto(
{'age': [0],
'gender': [0]}, 0),
make_example_proto(
{'age': [1],
'gender': [1]}, 1),
]
example_weights = [1.0, 1.0]
with self.test_session(use_gpu=False):
examples = make_example_dict(example_protos, example_weights)
variables = make_variable_dict(examples, 1, 1)
options = dict(symmetric_l2_regularization=0.5,
symmetric_l1_regularization=0,
loss_type='logistic_loss',
prior=0.0)
tf.initialize_all_variables().run()
lr = SdcaModel(examples, variables, options)
unregularized_loss = lr.unregularized_loss(examples)
loss = lr.regularized_loss(examples)
prediction = lr.predictions(examples)
self.assertAllClose(0.693147, unregularized_loss.eval())
self.assertAllClose(0.693147, loss.eval())
lr.minimize().run()
self.assertAllClose(0.395226, unregularized_loss.eval())
self.assertAllClose(0.657446, loss.eval())
predicted_labels = tf.cast(
tf.greater_equal(prediction,
tf.ones_like(prediction) * 0.5), tf.float32)
self.assertAllEqual([0, 1], predicted_labels.eval())
def testSomeUnweightedExamples(self):
# Set up test data with 4 examples; it should produce the same
# results as testSimple.
example_protos = [
# Will be used.
make_example_proto(
{'age': [0],
'gender': [0]}, 0),
# Will be ignored.
make_example_proto(
{'age': [1],
'gender': [0]}, 0),
# Will be used.
make_example_proto(
{'age': [1],
'gender': [1]}, 1),
# Will be ignored.
make_example_proto(
{'age': [1],
'gender': [0]}, 1),
]
example_weights = [1.0, 0.0, 1.0, 0.0]
with self.test_session(use_gpu=False):
# Only use examples 0 and 2
examples = make_example_dict(example_protos, example_weights)
variables = make_variable_dict(examples, 1, 1)
options = dict(symmetric_l2_regularization=0.25,
symmetric_l1_regularization=0,
loss_type='logistic_loss')
tf.initialize_all_variables().run()
lr = SdcaModel(examples, variables, options)
unregularized_loss = lr.unregularized_loss(examples)
loss = lr.regularized_loss(examples)
prediction = lr.predictions(examples)
lr.minimize().run()
self.assertAllClose(0.395226, unregularized_loss.eval())
self.assertAllClose(0.526336, loss.eval())
predicted_labels = tf.cast(
tf.greater_equal(prediction,
tf.ones_like(prediction) * 0.5), tf.float32)
self.assertAllClose([0, 1, 1, 1], predicted_labels.eval())
def testNoWeightedExamples(self):
# Setup test data with 1 positive, and 1 negative example.
example_protos = [
make_example_proto(
{'age': [0],
'gender': [0]}, 0),
make_example_proto(
{'age': [1],
'gender': [1]}, 1),
]
# Zeroed out example weights.
example_weights = [0.0, 0.0]
with self.test_session(use_gpu=False):
examples = make_example_dict(example_protos, example_weights)
variables = make_variable_dict(examples, 1, 1)
options = dict(symmetric_l2_regularization=0.5,
symmetric_l1_regularization=0,
loss_type='logistic_loss')
tf.initialize_all_variables().run()
lr = SdcaModel(examples, variables, options)
self.assertAllClose([0.5, 0.5], lr.predictions(examples).eval())
with self.assertRaisesOpError(
'No weighted examples in 2 training examples'):
lr.minimize().run()
self.assertAllClose([0.5, 0.5], lr.predictions(examples).eval())
def testImbalanced(self):
# Setup test data with 1 positive, and 3 negative examples.
example_protos = [
make_example_proto(
{'age': [0],
'gender': [0]}, 0),
make_example_proto(
{'age': [2],
'gender': [0]}, 0),
make_example_proto(
{'age': [3],
'gender': [0]}, 0),
make_example_proto(
{'age': [1],
'gender': [1]}, 1),
]
example_weights = [1.0, 1.0, 1.0, 1.0]
with self.test_session(use_gpu=False):
examples = make_example_dict(example_protos, example_weights)
variables = make_variable_dict(examples, 3, 1)
options = dict(symmetric_l2_regularization=0.25,
symmetric_l1_regularization=0,
loss_type='logistic_loss',
prior=-1.09861)
tf.initialize_all_variables().run()
lr = SdcaModel(examples, variables, options)
unregularized_loss = lr.unregularized_loss(examples)
loss = lr.regularized_loss(examples)
prediction = lr.predictions(examples)
lr.minimize().run()
self.assertAllClose(0.331710,
unregularized_loss.eval(),
rtol=1e-2,
atol=1e-2)
self.assertAllClose(0.591295, loss.eval(), rtol=1e-2, atol=1e-2)
predicted_labels = tf.cast(
tf.greater_equal(prediction,
tf.ones_like(prediction) * 0.5), tf.float32)
self.assertAllEqual([0, 0, 0, 1], predicted_labels.eval())
def testImbalancedWithExampleWeights(self):
# Set up test data with 1 positive and 1 negative example, where the
# negative example carries weight 3 (effectively 3 negative examples).
example_protos = [
make_example_proto(
{'age': [0],
'gender': [0]}, 0),
make_example_proto(
{'age': [1],
'gender': [1]}, 1),
]
example_weights = [3.0, 1.0]
with self.test_session(use_gpu=False):
examples = make_example_dict(example_protos, example_weights)
variables = make_variable_dict(examples, 1, 1)
options = dict(symmetric_l2_regularization=0.25,
symmetric_l1_regularization=0,
loss_type='logistic_loss')
tf.initialize_all_variables().run()
lr = SdcaModel(examples, variables, options)
unregularized_loss = lr.unregularized_loss(examples)
loss = lr.regularized_loss(examples)
prediction = lr.predictions(examples)
lr.minimize().run()
self.assertAllClose(0.266189,
unregularized_loss.eval(),
rtol=1e-2,
atol=1e-2)
self.assertAllClose(0.571912, loss.eval(), rtol=1e-2, atol=1e-2)
predicted_labels = tf.cast(
tf.greater_equal(prediction,
tf.ones_like(prediction) * 0.5), tf.float32)
self.assertAllEqual([0, 1], predicted_labels.eval())
if __name__ == '__main__':
googletest.main()


@ -0,0 +1,275 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Proximal stochastic dual coordinate ascent optimizer for linear models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
# pylint: disable=wildcard-import
from tensorflow.contrib.linear_optimizer.gen_sdca_ops import *
class SdcaModel(object):
"""Stochastic dual coordinate ascent solver for linear models.
This class currently only supports a single machine (multi-threaded)
implementation. We expect the data and weights to fit on a single machine.
Loss functions supported:
* Binary logistic loss
This class defines an optimizer API to train a linear model.
### Usage
```python
# Create a solver with the desired parameters.
lr = tf.contrib.linear_optimizer.SdcaModel(examples, variables, options)
opt_op = lr.minimize()
predictions = lr.predictions(examples)
# Primal loss + L1 loss + L2 loss.
regularized_loss = lr.regularized_loss(examples)
# Primal loss only
unregularized_loss = lr.unregularized_loss(examples)
examples: {
sparse_features: list of SparseTensors of value type float32.
dense_features: list of dense tensors of type float32.
example_labels: a tensor of shape [Num examples]
example_weights: a tensor of shape [Num examples]
}
variables: {
sparse_features_weights: list of tensors of shape [vocab size]
dense_features_weights: list of tensors of shape [1]
dual: tensor of shape [Num examples]
}
options: {
symmetric_l1_regularization: 0.0
symmetric_l2_regularization: 1.0
loss_type: "logistic_loss"
}
```
In the training program you will just have to run the returned Op from
minimize().
```python
# Execute opt_op once to perform training, which continues until
# convergence. The op uses the duality gap as a certificate for
# termination; the duality gap threshold defaults to 0.01.
opt_op.run()
```
"""
def __init__(self, examples, variables, options):
"""Create a new sdca optimizer."""
if not examples or not variables or not options:
raise ValueError('All arguments must be specified.')
if options['loss_type'] != 'logistic_loss':
raise ValueError('Optimizer only supports logistic regression (for now).')
self._assertSpecified(
['example_labels', 'example_weights', 'sparse_features',
'dense_features'], examples)
self._assertList(['sparse_features', 'dense_features'], examples)
self._assertSpecified(
['sparse_features_weights', 'dense_features_weights',
'training_log_loss', 'dual'], variables)
self._assertList(
['sparse_features_weights', 'dense_features_weights'], variables)
self._assertSpecified(
['loss_type', 'symmetric_l2_regularization',
'symmetric_l1_regularization'], options)
self._examples = examples
self._variables = variables
self._options = options
self._training_log_loss = tf.convert_to_tensor(
self._variables['training_log_loss'],
as_ref=True)
def _assertSpecified(self, items, check_in):
for x in items:
if check_in[x] is None:
raise ValueError(x + ' must be specified.')
def _assertList(self, items, check_in):
for x in items:
if not isinstance(check_in[x], list):
raise ValueError(x + ' must be a list.')
def _l1_loss(self):
""""Computes the l1 loss of the model."""
with tf.name_scope('l1_loss'):
sparse_weights = self._convert_n_to_tensor(self._variables[
'sparse_features_weights'])
dense_weights = self._convert_n_to_tensor(self._variables[
'dense_features_weights'])
l1 = self._options['symmetric_l1_regularization']
loss = 0
for w in sparse_weights:
loss += l1 * tf.reduce_sum(tf.abs(w))
for w in dense_weights:
loss += l1 * tf.reduce_sum(tf.abs(w))
return loss
def _l2_loss(self):
""""Computes the l1 loss of the model."""
with tf.name_scope('l2_loss'):
sparse_weights = self._convert_n_to_tensor(self._variables[
'sparse_features_weights'])
dense_weights = self._convert_n_to_tensor(self._variables[
'dense_features_weights'])
l2 = self._options['symmetric_l2_regularization']
loss = 0
for w in sparse_weights:
loss += l2 * tf.reduce_sum(tf.square(w))
for w in dense_weights:
loss += l2 * tf.reduce_sum(tf.square(w))
return loss
def _logits(self, examples):
"""Compute logits for each example."""
with tf.name_scope('logits'):
sparse_variables = self._convert_n_to_tensor(self._variables[
'sparse_features_weights'])
logits = 0
for st_i, sv in zip(examples['sparse_features'], sparse_variables):
ei, fi = tf.split(1, 2, st_i.indices)
ei = tf.reshape(ei, [-1])
fi = tf.reshape(fi, [-1])
fv = tf.reshape(st_i.values, [-1])
# TODO(rohananil): This does not work if examples have empty
# features.
logits += tf.segment_sum(
tf.mul(
tf.gather(sv, fi), fv), tf.reshape(ei, [-1]))
dense_features = self._convert_n_to_tensor(examples['dense_features'])
dense_variables = self._convert_n_to_tensor(self._variables[
'dense_features_weights'])
for i in xrange(len(dense_variables)):
logits += dense_features[i] * dense_variables[i]
return logits
def _convert_n_to_tensor(self, input_list, as_ref=False):
"""Converts input list to a set of tensors."""
return [tf.convert_to_tensor(x, as_ref=as_ref) for x in input_list]
def predictions(self, examples):
"""Add operations to compute predictions by the model.
Args:
examples: Examples to compute prediction on.
Returns:
An Operation that computes the predictions for examples. For logistic
loss, the output is a tensor with sigmoid outputs.
Raises:
ValueError: if examples are not well defined.
"""
self._assertSpecified(
['example_weights', 'sparse_features', 'dense_features'], examples)
self._assertList(['sparse_features', 'dense_features'], examples)
with tf.name_scope('sdca/prediction'):
logits = self._logits(examples)
# TODO(rohananil): Change prediction when supporting linear
# regression.
return tf.sigmoid(logits)
def minimize(self):
"""Add operations to train a linear model by minimizing the loss function.
Returns:
An Operation that updates the variables passed in the constructor.
"""
with tf.name_scope('sdca/minimize'):
sparse_features_indices = []
sparse_features_weights = []
for sf in self._examples['sparse_features']:
sparse_features_indices.append(tf.convert_to_tensor(sf.indices))
sparse_features_weights.append(tf.convert_to_tensor(sf.values))
return sdca_solver(
sparse_features_indices,
sparse_features_weights,
self._convert_n_to_tensor(self._examples['dense_features']),
tf.convert_to_tensor(self._examples['example_weights']),
tf.convert_to_tensor(self._examples['example_labels']),
tf.convert_to_tensor(self._variables['dual'],
as_ref=True),
self._convert_n_to_tensor(self._variables[
'sparse_features_weights'],
as_ref=True),
self._convert_n_to_tensor(self._variables['dense_features_weights'],
as_ref=True),
self._training_log_loss,
L1=self._options['symmetric_l1_regularization'],
L2=self._options['symmetric_l2_regularization'],
LossType=self._options['loss_type'])
def unregularized_loss(self, examples):
"""Add operations to compute the loss (without the regularization loss).
Args:
examples: Examples to compute unregularized loss on.
Returns:
An Operation that computes mean (unregularized) loss for given set of
examples.
Raises:
ValueError: if examples are not well defined.
"""
self._assertSpecified(
['example_labels', 'example_weights', 'sparse_features',
'dense_features'], examples)
self._assertList(['sparse_features', 'dense_features'], examples)
with tf.name_scope('sdca/unregularized_loss'):
logits = self._logits(examples)
# TODO(rohananil): Change loss when supporting linear regression.
return tf.reduce_sum(tf.mul(
tf.nn.sigmoid_cross_entropy_with_logits(logits, tf.convert_to_tensor(
examples['example_labels'])), tf.convert_to_tensor(examples[
'example_weights']))) / tf.reduce_sum(tf.convert_to_tensor(
examples['example_weights']))
def regularized_loss(self, examples):
"""Add operations to compute the loss with regularization loss included.
Args:
examples: Examples to compute loss on.
Returns:
An Operation that computes mean (regularized) loss for given set of
examples.
Raises:
ValueError: if examples are not well defined.
"""
self._assertSpecified(
['example_labels', 'example_weights', 'sparse_features',
'dense_features'], examples)
self._assertList(['sparse_features', 'dense_features'], examples)
with tf.name_scope('sdca/regularized_loss'):
logits = self._logits(examples)
# TODO(rohananil): Change loss when supporting linear regression.
return self._l1_loss() + self._l2_loss() + self.unregularized_loss(
examples)


@ -408,7 +408,7 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
// Check that this new set of fetches can be computed from all the
// feeds we have supplied.
TF_RETURN_IF_ERROR(
CheckFetch(inputs, output_names, executors_and_keys->graph, run_state));
CheckFetch(inputs, output_names, executors_and_keys, run_state));
// Send inputs.
Status s = SendInputs(inputs, executors_and_keys, run_state->rendez);
@ -510,12 +510,10 @@ Status DirectSession::RecvOutputs(const std::vector<string>& output_names,
Status DirectSession::CheckFetch(const NamedTensorList& feeds,
const std::vector<string>& fetches,
const Graph* graph,
const ExecutorsAndKeys* executors_and_keys,
const RunState* run_state) {
std::unordered_map<StringPiece, Node*, StringPiece::Hasher> name_to_node;
for (Node* n : graph->nodes()) {
name_to_node[n->name()] = n;
}
const Graph* graph = executors_and_keys->graph;
const NameNodeMap* name_to_node = executors_and_keys->name_to_node;
// Build the set of pending feeds that we haven't seen.
std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
@ -523,8 +521,8 @@ Status DirectSession::CheckFetch(const NamedTensorList& feeds,
mutex_lock l(executor_lock_);
for (const string& feed : run_state->pending_inputs) {
TensorId id(ParseTensorName(feed));
auto it = name_to_node.find(id.first);
if (it == name_to_node.end()) {
auto it = name_to_node->find(id.first);
if (it == name_to_node->end()) {
return errors::NotFound("Feed ", feed, ": not found");
}
pending_feeds.insert(id);
@ -539,8 +537,8 @@ Status DirectSession::CheckFetch(const NamedTensorList& feeds,
std::vector<const Node*> stack;
for (const string& fetch : fetches) {
TensorId id(ParseTensorName(fetch));
auto it = name_to_node.find(id.first);
if (it == name_to_node.end()) {
auto it = name_to_node->find(id.first);
if (it == name_to_node->end()) {
return errors::NotFound("Fetch ", fetch, ": not found");
}
stack.push_back(it->second);
@ -618,6 +616,10 @@ Status DirectSession::GetOrCreateExecutors(
ek->func_defs = fdefs;
if (run_state_args->is_partial_run) {
ek->graph = run_state_args->graph;
ek->name_to_node = new NameNodeMap;
for (Node* n : run_state_args->graph->nodes()) {
ek->name_to_node->insert({n->name(), n});
}
}
ek->items.reserve(graphs.size());
auto runner = [this](Executor::Args::Closure c) { SchedClosure(c); };


@ -49,6 +49,8 @@ class DirectSession : public Session {
~DirectSession() override;
typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher>
NameNodeMap;
::tensorflow::Status Create(const GraphDef& graph) override;
::tensorflow::Status Extend(const GraphDef& graph) override;
@ -81,13 +83,15 @@ class DirectSession : public Session {
// An ExecutorsAndKeys is created for a given set of feeds/fetches.
// 'func_defs' are the function definition used by all the
// underlying executors. 'graph' is the entire graph being
// executed. Each item in 'items' is the executor for a
// partition of the graph bundled with its dependent library
// runtime. 'input_keys' are the rendezvous keys for the feeds and
// 'output_keys' are rendezvous keys for the fetches.
// executed. 'name_to_node' maps node name to node. We keep 'graph'
// and 'name_to_node' only in the case of partial runs. Each item in
// 'items' is the executor for a partition of the graph bundled with
// its dependent library runtime. 'input_keys' are the rendezvous keys
// for the feeds and 'output_keys' are rendezvous keys for the fetches.
struct ExecutorsAndKeys {
FunctionLibraryDefinition* func_defs = nullptr;
Graph* graph = nullptr;
NameNodeMap* name_to_node = nullptr;
std::vector<PerPartitionExecutorsAndLib> items;
std::unordered_map<string, string> input_keys;
std::unordered_map<string, string> output_keys;
@ -99,6 +103,7 @@ class DirectSession : public Session {
}
delete func_defs;
delete graph;
delete name_to_node;
}
};
@ -177,8 +182,8 @@ class DirectSession : public Session {
// that we have already provided.
::tensorflow::Status CheckFetch(
const std::vector<std::pair<string, Tensor>>& feeds,
const std::vector<string>& fetches, const Graph* graph,
const RunState* run_state);
const std::vector<string>& fetches,
const ExecutorsAndKeys* executors_and_keys, const RunState* run_state);
const SessionOptions options_;


@ -82,7 +82,8 @@ def main(_):
# Merge all the summaries and write them out to /tmp/mnist_logs (by default)
merged = tf.merge_all_summaries()
writer = tf.train.SummaryWriter(FLAGS.summaries_dir, sess.graph_def)
writer = tf.train.SummaryWriter(FLAGS.summaries_dir,
sess.graph.as_graph_def(add_shapes=True))
tf.initialize_all_variables().run()
# Train the model, and feed in test data and record summaries every 10 steps


@ -77,6 +77,11 @@ $ sudo easy_install --upgrade six
$ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.7.0-py3-none-any.whl
```
NOTE: If you are upgrading from a previous installation of TensorFlow < 0.7.1,
you should uninstall the previous TensorFlow *and protobuf* using `pip
uninstall` first to make sure you get a clean installation of the updated
protobuf dependency.
You can now [test your installation](#test-the-tensorflow-installation).
@ -582,6 +587,19 @@ explicitly.
### Pip installation issues
#### Cannot import name 'descriptor'
```python
ImportError: Traceback (most recent call last):
File "/usr/local/lib/python3.4/dist-packages/tensorflow/core/framework/graph_pb2.py", line 6, in <module>
from google.protobuf import descriptor as _descriptor
ImportError: cannot import name 'descriptor'
```
If you see the above error when upgrading to a newer version of TensorFlow, try
uninstalling both TensorFlow and protobuf (if installed) and re-installing
TensorFlow (which will also install the correct protobuf dependency).
#### Can't find setup.py
If, during `pip install`, you encounter an error like:


@ -228,3 +228,26 @@ The images below give an illustration for a piece of a real-life graph.
</td>
</tr>
</table>
## Tensor shape information
When the serialized `GraphDef` includes tensor shapes, the graph visualizer
labels edges with tensor dimensions, and edge thickness reflects total tensor
size. To include tensor shapes in the `GraphDef` pass
`sess.graph.as_graph_def(add_shapes=True)` to the `SummaryWriter` when
serializing the graph. The images below show the CIFAR-10 model with tensor
shape information:
<table width="100%;">
<tr>
<td style="width: 100%;">
<img src="../../images/tensor_shapes.png" alt="CIFAR-10 model with tensor shape information" title="CIFAR-10 model with tensor shape information" />
</td>
</tr>
<tr>
<td style="width: 100%;">
CIFAR-10 model with tensor shape information.
</td>
</tr>
</table>
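
A minimal snippet of the call described above, assuming a default session and a log directory of your choice:

```python
import tensorflow as tf

sess = tf.Session()
# Serialize the graph with tensor shapes attached so the visualizer can
# label edges with dimensions and scale edge thickness by tensor size.
writer = tf.train.SummaryWriter("/tmp/logs",
                                sess.graph.as_graph_def(add_shapes=True))
```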


@ -57,6 +57,10 @@ The `SummaryWriter` takes a logdir in its constructor - this logdir is quite
important, it's the directory where all of the events will be written out.
Also, the `SummaryWriter` can optionally take a `GraphDef` in its constructor.
If it receives one, then TensorBoard will visualize your graph as well.
To include tensor shape information in the `GraphDef`, pass
`sess.graph.as_graph_def(add_shapes=True)` to the `SummaryWriter`. This will
give you a much better sense of what flows through the graph: see
[Tensor shape information](../../how_tos/graph_viz/index.md#tensor-shape-information).
Now that you've modified your graph and have a `SummaryWriter`, you're ready to
start running your network! If you want, you could run the merged summary op
@ -102,7 +106,8 @@ with tf.name_scope("test") as scope:
# Merge all the summaries and write them out to /tmp/mnist_logs
merged = tf.merge_all_summaries()
writer = tf.train.SummaryWriter("/tmp/mnist_logs", sess.graph_def)
writer = tf.train.SummaryWriter("/tmp/mnist_logs",
sess.graph.as_graph_def(add_shapes=True))
tf.initialize_all_variables().run()
# Train the model, and feed in test data and record summaries every 10 steps


@ -136,7 +136,7 @@ def evaluate():
# Build the summary operation based on the TF collection of Summaries.
summary_op = tf.merge_all_summaries()
graph_def = tf.get_default_graph().as_graph_def()
graph_def = tf.get_default_graph().as_graph_def(add_shapes=True)
summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir,
graph_def=graph_def)


@ -239,8 +239,9 @@ def train():
# Start the queue runners.
tf.train.start_queue_runners(sess=sess)
graph_def = sess.graph.as_graph_def(add_shapes=True)
summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
graph_def=sess.graph_def)
graph_def=graph_def)
for step in xrange(FLAGS.max_steps):
start_time = time.time()


@ -93,8 +93,9 @@ def train():
# Start the queue runners.
tf.train.start_queue_runners(sess=sess)
graph_def = sess.graph.as_graph_def(add_shapes=True)
summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
graph_def=sess.graph_def)
graph_def=graph_def)
for step in xrange(FLAGS.max_steps):
start_time = time.time()