Add correct dependencies to sdca ops to fix build breakage.
Change: 115408162
This commit is contained in:
parent
185cff7f41
commit
94a992cfc3
@@ -13,6 +13,7 @@ py_library(
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/contrib/layers:layers_py",
        "//tensorflow/contrib/linear_optimizer:sdca_ops_py",
        "//tensorflow/contrib/util:util_py",
    ],
)
@@ -21,4 +21,5 @@ from __future__ import print_function

# Add projects here, they will show up under tf.contrib.
from tensorflow.contrib import layers
from tensorflow.contrib import linear_optimizer
from tensorflow.contrib import util
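With the dependency and the import above in place, the new package is reachable from the public contrib namespace. A minimal sanity check (a sketch, assuming the `sdca_ops_py` target is built into the installed package) could be:

```python
import tensorflow as tf

# The import added above re-exports the linear_optimizer package under
# tf.contrib, so the SDCA solver class introduced later in this change is
# reachable without a deep import.
assert hasattr(tf.contrib.linear_optimizer, 'SdcaModel')
```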
85 tensorflow/contrib/linear_optimizer/BUILD Normal file
@@ -0,0 +1,85 @@
# Description:
#   Contains ops to train linear models on top of TensorFlow.
#   APIs here are meant to evolve over time.

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

package(default_visibility = ["//tensorflow:__subpackages__"])

load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")

cc_library(
    name = "sdca_kernel",
    srcs = ["kernels/sdca_ops.cc"],
    hdrs = ["kernels/logistic-loss.h"],
    visibility = ["//visibility:private"],
    deps = [
        "//third_party/eigen3",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all",
    ],
    alwayslink = 1,
)

cc_library(
    name = "sdca_ops",
    srcs = ["ops/sdca_ops.cc"],
    deps = [
        "//tensorflow/core:framework",
    ],
    alwayslink = 1,
)

tf_gen_op_libs(
    op_lib_names = ["sdca_ops"],
)

tf_gen_op_wrapper_py(
    name = "gen_sdca_ops_py",
    out = "gen_sdca_ops.py",
    deps = [":sdca_ops"],
)

py_library(
    name = "sdca_ops_py",
    srcs = [
        "__init__.py",
        "python/ops/sdca_ops.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":gen_sdca_ops_py",
    ],
)

py_test(
    name = "sdca_ops_test",
    srcs = ["python/kernel_tests/sdca_ops_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":sdca_kernel",
        ":sdca_ops",
        ":sdca_ops_py",
        "//third_party/py/tensorflow",
        "//tensorflow/core:all_kernels",
        "//tensorflow/core:protos_all_py_pb2",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
21 tensorflow/contrib/linear_optimizer/__init__.py Normal file
@@ -0,0 +1,21 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for training linear models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=wildcard-import
from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import *
103 tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h Normal file
@@ -0,0 +1,103 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_LOGISTIC_LOSS_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_LOGISTIC_LOSS_H_

#include <algorithm>
#include <cmath>

namespace tensorflow {
struct logistic_loss {
  // Use an approximate step that is guaranteed to decrease the dual loss.
  // Derivation of this is available in Page 14 Eq 16 of
  // http://arxiv.org/pdf/1211.2717v1.pdf
  inline static double ComputeUpdatedDual(const double label,
                                          const double example_weight,
                                          const double current_dual,
                                          const double wx,
                                          const double weighted_example_norm,
                                          const double primal_loss,
                                          const double dual_loss) {
    const double ywx = label * wx;
    // To avoid overflow, we compute the derivative of the logistic loss with
    // respect to the log-odds as follows.
    double inverse_exp_term = 0;
    if (ywx > 0) {
      const double exp_minus_ywx = exp(-ywx);
      inverse_exp_term = exp_minus_ywx / (1 + exp_minus_ywx);
    } else {
      inverse_exp_term = 1 / (1 + exp(ywx));
    }
    // f(a) = sup (a*x - f(x)) then a = f'(x), where a is the approximate dual.
    const double approximate_dual = inverse_exp_term * label;
    const double delta_dual = approximate_dual - current_dual;
    // Upper bound on the smoothness constant of log loss. This is 0.25 i.e.
    // when the log-odds is zero.
    const double gamma =
        (wx == 0) ? 0.25 : (1 - 2 * inverse_exp_term) / (2 * ywx);
    const double wx_dual = wx * current_dual * example_weight;
    const double delta_dual_squared = delta_dual * delta_dual;
    const double smooth_delta_dual_squared = delta_dual_squared * gamma * 0.5;
    double multiplier =
        (primal_loss + dual_loss + wx_dual + smooth_delta_dual_squared) /
        std::max(1.0,
                 delta_dual_squared *
                     (gamma +
                      weighted_example_norm * example_weight * example_weight));
    // Multiplier must be in the range [0, 1].
    multiplier = std::max(std::min(1.0, multiplier), 0.0);
    return current_dual + delta_dual * multiplier;
  }

  // Dual of the logistic loss function.
  // https://en.wikipedia.org/wiki/Convex_conjugate
  inline static double ComputeDualLoss(const double current_dual,
                                       const double example_label,
                                       const double example_weight) {
    // The dual of the logistic loss function is
    // ay * log(ay) + (1-ay) * log(1-ay), where a is the dual variable.
    const double ay = current_dual * example_label;
    const double log_ay = (ay > 0) ? log(ay) : 0;
    const double one_minus_ay = 1 - ay;
    const double log_one_minus_ay = (one_minus_ay > 0) ? log(one_minus_ay) : 0;
    return ((ay * log_ay) + (one_minus_ay * log_one_minus_ay)) * example_weight;
  }

  // Logistic loss for binary classification.
  // https://en.wikipedia.org/wiki/Loss_functions_for_classification
  inline static double ComputePrimalLoss(const double wx,
                                         const double example_label,
                                         const double example_weight) {
    // Logistic loss:
    //   log(1 + e^(-ywx))
    //   log(e^0 + e^(-ywx))
    //   a + log(e^(0-a) + e^(-ywx - a)),  where a is max(0, -ywx)
    // https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
    const double y_wx = example_label * wx;
    if (y_wx > 0) {
      // 0 + log(e^(0) + e^(-ywx - 0))
      // log(1 + e^(-ywx))
      return log(1 + exp(-y_wx)) * example_weight;
    }
    // -ywx + log(e^(ywx) + e^(-ywx + ywx))
    // log(e^(ywx) + e^(0)) - ywx
    // log(1 + e^(ywx)) - ywx
    return (log(1 + exp(y_wx)) - y_wx) * example_weight;
  }
};
}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_LOGISTIC_LOSS_H_
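Read as math, the helpers above fit together as follows (notation is ours, not from the patch): ComputePrimalLoss is the logistic loss of the logit, ComputeDualLoss is its convex conjugate evaluated at the label-signed dual variable, and the solver kernel accumulates their weighted sum, plus the regularization terms, into the approximate duality gap it uses as a stopping certificate.

```latex
% Sketch of the quantities implemented in logistic-loss.h.
% y \in \{-1,+1\}: label, \; u = w^{\top}x: logit, \; \alpha: dual variable.
\ell(y\,u) = \log\!\bigl(1 + e^{-y\,u}\bigr)                                  % ComputePrimalLoss
\qquad
\ell^{*}(-a) = a\log a + (1-a)\log(1-a), \quad a = \alpha y \in [0,1]         % ComputeDualLoss
\qquad
\text{gap}_i \approx \ell(y_i u_i) + \ell^{*}(-\alpha_i y_i)                  % accumulated by SdcaSolver
```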
458 tensorflow/contrib/linear_optimizer/kernels/sdca_ops.cc Normal file
@@ -0,0 +1,458 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/sdca_ops.cc.

#define EIGEN_USE_THREADS
#include <atomic>
#include <memory>
#include <random>
#include <string>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
namespace {

// A feature group of a single example is represented by this struct.
struct PerExampleSparseIndicesWeights {
  // N X 2 matrix with (example_id, feature_indices).
  tensorflow::TTypes<const int64>::UnalignedMatrix indices;
  // N X 1 vector with feature weights.
  tensorflow::TTypes<float>::UnalignedVec values;
  // Sum squared norm of the features.
  double norm;
};

struct Regularizations {
  float symmetric_l1 = 0;
  float symmetric_l2 = 0;
};

struct RegularizationLoss {
  double l1_loss = 0;
  double l2_loss = 0;
};

struct PerExampleData {
  double wx = 0;
  double norm = 0;
};

// Tensor vector of floats which holds the weights.
using Weights = TTypes<float>::Vec;

// Weights associated with a feature group, such that the size of
// WeightsByIndex is the number of feature groups.
using WeightsByIndex = std::vector<Weights>;

// SparseExamples represent the sparse feature groups of each example.
using SparseExamples =
    vector<std::unique_ptr<const PerExampleSparseIndicesWeights>>;

// SparseExamples associated with each sparse feature group.
using SparseExamplesByIndex = vector<SparseExamples>;

using DenseFeaturesByIndex = vector<tensorflow::TTypes<const float>::Vec>;

// Compute the shrinkage factor for proximal sdca.
inline double ShrinkageFactor(const Regularizations& regularizations) {
  return regularizations.symmetric_l1 / regularizations.symmetric_l2;
}

// Proximal SDCA shrinking for L1 regularization.
inline double Shrink(const double weight, const double shrink_by) {
  const double shrink_weight = std::max(std::abs(weight) - shrink_by, 0.0);
  if (shrink_weight > 0.0) {
    return std::copysign(shrink_weight, weight);
  }
  return 0.0;
}

// Compute L1 and L2 regularization loss.
inline RegularizationLoss ComputeRegularizationLoss(
    const WeightsByIndex& sparse_weights_by_index,
    const WeightsByIndex& dense_weights_by_index,
    const Regularizations& regularizations) {
  RegularizationLoss result;

  const double shrink_by = ShrinkageFactor(regularizations);
  auto accumulate_regularization_loss = [&](const double w) {
    const double sw = std::abs(Shrink(w, shrink_by));
    result.l1_loss += sw;
    result.l2_loss += sw * sw;
  };

  for (auto& sparse_weights : sparse_weights_by_index) {
    for (size_t i = 0; i < sparse_weights.size(); ++i) {
      accumulate_regularization_loss(sparse_weights(i));
    }
  }

  for (auto& dense_weights : dense_weights_by_index) {
    accumulate_regularization_loss(dense_weights(0));
  }

  result.l1_loss *= regularizations.symmetric_l1;
  result.l2_loss *= regularizations.symmetric_l2;
  return result;
}

// Compute PerExampleData which contains the logits, and weighted example norm
// for a given example_id. Norm is weighted by 1/(lambda*N).
inline PerExampleData ComputeWxAndWeightedExampleNorm(
    const int64 example_id, const WeightsByIndex& sparse_weights_by_index,
    const SparseExamplesByIndex& sparse_examples_by_index,
    const WeightsByIndex& dense_weights_by_index,
    const DenseFeaturesByIndex& dense_features,
    const Regularizations& regularizations) {
  PerExampleData result;
  const double shrink_by = ShrinkageFactor(regularizations);
  for (size_t i = 0; i < sparse_examples_by_index.size(); ++i) {
    const SparseExamples& sparse_indices_values = sparse_examples_by_index[i];
    const Weights sparse_weights = sparse_weights_by_index[i];
    if (sparse_indices_values[example_id]) {
      const auto indices = sparse_indices_values[example_id]->indices;
      const auto values = sparse_indices_values[example_id]->values;
      for (size_t dim = 0; dim < indices.dimension(0); ++dim) {
        result.wx +=
            Shrink(sparse_weights(indices(dim, 1)), shrink_by) * values(dim);
      }
      result.norm += sparse_indices_values[example_id]->norm;
    }
  }
  for (size_t i = 0; i < dense_features.size(); ++i) {
    const auto dense_values = dense_features[i];
    const Weights dense_weights = dense_weights_by_index[i];
    result.wx += Shrink(dense_weights(0), shrink_by) * dense_values(example_id);
    result.norm += dense_values(example_id) * dense_values(example_id);
  }
  result.norm /= regularizations.symmetric_l2;
  return result;
}

// Apply L1 regularization on the weights.
void ShrinkWeights(const Regularizations& regularizations,
                   WeightsByIndex* const sparse_weights_by_index,
                   WeightsByIndex* const dense_weights_by_index) {
  const double shrink_by = ShrinkageFactor(regularizations);
  for (auto& sparse_weights : *sparse_weights_by_index) {
    for (size_t i = 0; i < sparse_weights.size(); ++i) {
      sparse_weights(i) = Shrink(sparse_weights(i), shrink_by);
    }
  }
  for (auto& dense_weights : *dense_weights_by_index) {
    dense_weights(0) = Shrink(dense_weights(0), shrink_by);
  }
}

void UpdateWeights(const int64 example_id,
                   const SparseExamplesByIndex& sparse_examples_by_index,
                   const DenseFeaturesByIndex& dense_features,
                   const double bounded_dual_delta,
                   const double l2_regularization,
                   WeightsByIndex* const sparse_weights_by_index,
                   WeightsByIndex* const dense_weights_by_index) {
  for (size_t i = 0; i < sparse_examples_by_index.size(); ++i) {
    const SparseExamples& sparse_indices_values = sparse_examples_by_index[i];
    Weights sparse_weights = (*sparse_weights_by_index)[i];
    if (sparse_indices_values[example_id]) {
      const auto indices = sparse_indices_values[example_id]->indices;
      const auto values = sparse_indices_values[example_id]->values;
      for (size_t dim = 0; dim < indices.dimension(0); ++dim) {
        // TODO(rohananil): Atomic updates provide better convergence
        // guarantees. However, casting float to atomic<float> is UB. We may
        // consider a sharded set of locks, or bring the primal-dual
        // relationship to a consistent state after several epochs.
        sparse_weights(indices(dim, 1)) +=
            bounded_dual_delta * values(dim) / l2_regularization;
      }
    }
  }
  for (size_t i = 0; i < dense_features.size(); ++i) {
    const auto dense_values = dense_features[i];
    Weights dense_weights = (*dense_weights_by_index)[i];
    // TODO(rohananil): Atomic updates provide better convergence guarantees.
    // However, casting float to atomic<float> is UB. We may consider a
    // sharded set of locks, or bring the primal-dual relationship to a
    // consistent state after several epochs.
    dense_weights(0) +=
        bounded_dual_delta * dense_values(example_id) / l2_regularization;
  }
}

// Atomically add a double to a std::atomic<double>.
inline void AtomicAdd(const double value, std::atomic<double>* const dst) {
  // We use a strong version of compare-exchange, as the weak version can
  // spuriously fail.
  for (double c = dst->load(); !dst->compare_exchange_strong(c, c + value);) {
  }
}

}  // namespace

class SdcaSolver : public OpKernel {
 public:
  explicit SdcaSolver(OpKernelConstruction* context) : OpKernel(context) {
    string loss_type;
    OP_REQUIRES_OK(context, context->GetAttr("LossType", &loss_type));
    if (loss_type == "logistic_loss") {
      compute_dual_loss_ = logistic_loss::ComputeDualLoss;
      compute_primal_loss_ = logistic_loss::ComputePrimalLoss;
      compute_dual_update_ = logistic_loss::ComputeUpdatedDual;
    }
    OP_REQUIRES_OK(
        context, context->GetAttr("NumSparseFeatures", &num_sparse_features_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("NumDenseFeatures", &num_dense_features_));
    OP_REQUIRES(
        context, num_sparse_features_ + num_dense_features_ > 0,
        errors::InvalidArgument("Requires at least one feature to train."));

    OP_REQUIRES_OK(context,
                   context->GetAttr("L1", &regularizations_.symmetric_l1));
    OP_REQUIRES_OK(context,
                   context->GetAttr("L2", &regularizations_.symmetric_l2));
    OP_REQUIRES_OK(context, context->GetAttr("DualityGapThreshold",
                                             &duality_gap_threshold_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor* example_weights_t;
    OP_REQUIRES_OK(context,
                   context->input("example_weights", &example_weights_t));
    const auto example_weights = example_weights_t->vec<float>();

    Tensor primal_loss_t;
    OP_REQUIRES_OK(context,
                   context->mutable_input("primal_loss", &primal_loss_t,
                                          /*lock_held=*/true));
    auto primal_loss = primal_loss_t.scalar<double>();

    const int64 num_examples = example_weights.size();

    Eigen::Tensor<float, 0, Eigen::RowMajor> example_weights_sum;
    example_weights_sum.device(context->eigen_cpu_device()) =
        example_weights.sum();
    const float weighted_examples = example_weights_sum();

    OP_REQUIRES(context, weighted_examples > 0.0,
                errors::InvalidArgument("No weighted examples in ",
                                        num_examples, " training examples"));
    // We scale it up by the weighted examples.
    regularizations_.symmetric_l1 =
        regularizations_.symmetric_l1 * weighted_examples;
    regularizations_.symmetric_l2 =
        std::max(regularizations_.symmetric_l2 * weighted_examples, 1.0f);

    OpInputList sparse_features_indices_inputs;
    OP_REQUIRES_OK(context,
                   context->input_list("sparse_features_indices",
                                       &sparse_features_indices_inputs));
    OpInputList sparse_features_values;
    OP_REQUIRES_OK(context, context->input_list("sparse_features_values",
                                                &sparse_features_values));
    SparseExamplesByIndex sparse_examples_by_index(num_sparse_features_);
    // Goes through the entire training set once, so that we create per-example
    // structures. We use it downstream for randomizing the order of the
    // examples.
    auto do_parse = [&](const int64 begin, const int64 end) {
      // We set the order as [0, 1], which specifies that it is row-major
      // increasing. This means the first column has ids which are
      // lexicographically increasing.
      static const int64 kIndicesDims = 2;
      gtl::InlinedVector<int64, 8> order(kIndicesDims);
      std::iota(order.begin(), order.end(), 0);
      for (size_t i = begin; i < end; ++i) {
        OP_REQUIRES(context, sparse_features_indices_inputs[i].shape().dims() ==
                                 kIndicesDims,
                    errors::InvalidArgument(
                        "Indices should have exactly 2 dimensions"));
        tensorflow::sparse::SparseTensor st(
            sparse_features_indices_inputs[i], sparse_features_values[i],
            sparse_features_indices_inputs[i].shape(), order);
        sparse_examples_by_index[i] = SparseExamples(num_examples);
        for (const auto& example_group : st.group({0})) {
          const int64 example_id = example_group.indices()(0, 0);
          const Eigen::Tensor<float, 0, Eigen::RowMajor> norm =
              example_group.values<float>().square().sum();
          sparse_examples_by_index[i][example_id].reset(
              new PerExampleSparseIndicesWeights{example_group.indices(),
                                                 example_group.values<float>(),
                                                 norm()});
        }
      }
    };
    {
      const DeviceBase::CpuWorkerThreads* const worker_threads =
          context->device()->tensorflow_cpu_worker_threads();
      // For each column, the cost of parsing it is O(num_examples). We use
      // num_examples here, as empirically Shard() creates the right amount of
      // threads based on the problem size.
      // TODO(rohananil): Tune this as a function of dataset size.
      const int64 kCostPerUnit = num_examples;
      Shard(worker_threads->num_threads, worker_threads->workers,
            num_sparse_features_, kCostPerUnit, do_parse);
    }

    OpInputList dense_features_inputs;
    OP_REQUIRES_OK(
        context, context->input_list("dense_features", &dense_features_inputs));

    DenseFeaturesByIndex dense_features;
    for (auto& dense_features_input : dense_features_inputs) {
      dense_features.emplace_back(dense_features_input.vec<float>());
    }

    const Tensor* example_labels_t;
    OP_REQUIRES_OK(context,
                   context->input("example_labels", &example_labels_t));
    const auto example_labels = example_labels_t->vec<float>();

    Tensor dual_variables_t;
    OP_REQUIRES_OK(context,
                   context->mutable_input("dual_variables", &dual_variables_t,
                                          /*lock_held=*/true));
    auto dual_variables = dual_variables_t.vec<float>();

    OpMutableInputList sparse_weights_by_index_inputs;
    OP_REQUIRES_OK(
        context, context->mutable_input_list("sparse_weights_by_index",
                                             &sparse_weights_by_index_inputs));
    WeightsByIndex sparse_weights_by_index;
    for (size_t i = 0; i < sparse_weights_by_index_inputs.size(); ++i) {
      sparse_weights_by_index.emplace_back(
          sparse_weights_by_index_inputs.at(i, /*lock_held=*/true)
              .flat<float>());
    }

    OpMutableInputList dense_weights_by_index_inputs;
    OP_REQUIRES_OK(context,
                   context->mutable_input_list("dense_weights_by_index",
                                               &dense_weights_by_index_inputs));
    WeightsByIndex dense_weights_by_index;
    for (size_t i = 0; i < dense_weights_by_index_inputs.size(); ++i) {
      dense_weights_by_index.emplace_back(
          dense_weights_by_index_inputs.at(i, /*lock_held=*/true)
              .flat<float>());
    }

    vector<int64> example_ids(num_examples);
    std::iota(example_ids.begin(), example_ids.end(), 0);
    std::random_device random_device;
    std::mt19937 random_generator(random_device());
    std::atomic<double> total_approx_duality_gap(
        std::numeric_limits<double>::max());
    std::atomic<double> total_primal_loss(0);
    // Break when the duality gap |P(w) - D(alpha)| is less than
    // duality_gap_threshold_.
    while ((total_approx_duality_gap / weighted_examples) >
           duality_gap_threshold_) {
      // Reset and add everything.
      total_approx_duality_gap = 0;
      total_primal_loss = 0;
      std::shuffle(example_ids.begin(), example_ids.end(), random_generator);
      auto do_update = [&](const int64 begin, const int64 end) {
        double approx_duality_gap = 0;
        for (int64 offset = begin; offset < end; ++offset) {
          // Get example id, label, and weight.
          const int64 example_id = example_ids[offset];
          OP_REQUIRES(context, !(example_labels(example_id) > 0 &&
                                 example_labels(example_id) < 1),
                      errors::InvalidArgument(
                          "Fractional labels not supported right now. "
                          "Found example with label: ",
                          example_labels(example_id)));
          const float example_label = example_labels(example_id) == 0 ? -1 : 1;
          const double current_dual = dual_variables(example_id);
          const double example_weight = example_weights(example_id);

          // Compute wx, example norm weighted by regularization, dual loss,
          // primal loss.
          const PerExampleData per_example_data =
              ComputeWxAndWeightedExampleNorm(
                  example_id, sparse_weights_by_index, sparse_examples_by_index,
                  dense_weights_by_index, dense_features, regularizations_);

          const double dual_loss = compute_dual_loss_(
              current_dual, example_label, example_weight);
          const double primal_loss = compute_primal_loss_(
              per_example_data.wx, example_label, example_weight);
          approx_duality_gap += dual_loss + primal_loss;

          // Update dual variable.
          dual_variables(example_id) = compute_dual_update_(
              example_label, example_weight, current_dual, per_example_data.wx,
              per_example_data.norm, primal_loss, dual_loss);

          // Compute new weights.
          const double bounded_dual_delta =
              (dual_variables(example_id) - current_dual) * example_weight;
          UpdateWeights(example_id, sparse_examples_by_index, dense_features,
                        bounded_dual_delta, regularizations_.symmetric_l2,
                        &sparse_weights_by_index, &dense_weights_by_index);

          AtomicAdd(approx_duality_gap, &total_approx_duality_gap);
          AtomicAdd(primal_loss, &total_primal_loss);
        }
        // TODO(rohananil): We may in the future want to make the primal-dual
        // relationship consistent as our current updates are not transactional.
      };
      const DeviceBase::CpuWorkerThreads* const worker_threads =
          context->device()->tensorflow_cpu_worker_threads();
      const int64 kCostPerUnit =
          100000 * (num_sparse_features_ + num_dense_features_);
      Shard(worker_threads->num_threads, worker_threads->workers, num_examples,
            kCostPerUnit, do_update);
      const RegularizationLoss regularization_loss = ComputeRegularizationLoss(
          sparse_weights_by_index, dense_weights_by_index, regularizations_);
      total_approx_duality_gap.store(total_approx_duality_gap.load() +
                                     regularization_loss.l1_loss +
                                     regularization_loss.l2_loss);
      primal_loss() = (total_primal_loss.load() + regularization_loss.l1_loss +
                       regularization_loss.l2_loss) /
                      weighted_examples;
    }
    ShrinkWeights(regularizations_, &sparse_weights_by_index,
                  &dense_weights_by_index);
  }

 private:
  std::function<decltype(logistic_loss::ComputeDualLoss)> compute_dual_loss_;
  std::function<decltype(logistic_loss::ComputePrimalLoss)>
      compute_primal_loss_;
  std::function<decltype(logistic_loss::ComputeUpdatedDual)>
      compute_dual_update_;
  int64 num_sparse_features_;
  int64 num_dense_features_;
  Regularizations regularizations_;
  float duality_gap_threshold_;
};

REGISTER_KERNEL_BUILDER(Name("SdcaSolver").Device(DEVICE_CPU), SdcaSolver);

}  // namespace tensorflow
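The ShrinkageFactor/Shrink pair in this kernel is the usual soft-thresholding step of proximal SDCA for L1 regularization; written out (again a sketch, with lambda_1 and lambda_2 denoting the scaled symmetric_l1/symmetric_l2 values from Compute):

```latex
\operatorname{Shrink}(w) \;=\; \operatorname{sign}(w)\,\max\bigl(|w| - \gamma,\; 0\bigr),
\qquad \gamma \;=\; \frac{\lambda_1}{\lambda_2}.
```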
74 tensorflow/contrib/linear_optimizer/ops/sdca_ops.cc Normal file
@@ -0,0 +1,74 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"
namespace tensorflow {

// --------------------------------------------------------------------------

REGISTER_OP("SdcaSolver")
    .Attr("LossType: {'logistic_loss'}")
    .Attr("NumSparseFeatures: int >= 0")
    .Attr("NumDenseFeatures: int >= 0")
    .Attr("L1: float >= 0")
    .Attr("L2: float >= 0")
    .Attr("DualityGapThreshold: float = 0.01")
    .Input("sparse_features_indices: NumSparseFeatures * int64")
    .Input("sparse_features_values: NumSparseFeatures * float")
    .Input("dense_features: NumDenseFeatures * float")
    .Input("example_weights: float")
    .Input("example_labels: float")
    .Input("dual_variables: Ref(float)")
    .Input(
        "sparse_weights_by_index: Ref(NumSparseFeatures * "
        "float)")
    .Input(
        "dense_weights_by_index: Ref(NumDenseFeatures * "
        "float)")
    .Input("primal_loss: Ref(double)")
    .Doc(R"doc(
Stochastic Dual Coordinate Ascent (SDCA) optimizer for linear models with
L1 + L2 regularization. As the global optimization objective is strongly-convex,
the optimizer optimizes the dual objective at each step. The optimizer applies
each update one example at a time. Examples are sampled uniformly, and the
optimizer is learning-rate free and enjoys a linear convergence rate.

Proximal Stochastic Dual Coordinate Ascent, Shalev-Shwartz, Shai; Zhang, Tong.
2012arXiv1211.2717S: http://arxiv.org/pdf/1211.2717v1.pdf

LossType: Type of the primal loss. Only logistic_loss is supported.
NumSparseFeatures: Number of sparse features to train on.
NumDenseFeatures: Number of dense features to train on.
L1: Per example symmetric l1 regularization strength.
L2: Per example symmetric l2 regularization strength.
DualityGapThreshold: Gap threshold at which we should stop training.
sparse_features_indices: a list of matrices with two columns that contain
  example_indices, and feature_indices.
sparse_features_values: a list of vectors which contains the feature value
  associated with each feature index.
dense_features: a list of vectors which contains the dense feature values.
example_weights: a vector which contains the example weight associated with
  each example.
example_labels: a vector which contains the example label/target associated
  with each example.
dual_variables: a vector which contains the dual variable associated with each
  example.
sparse_weights_by_index: a list of vectors where each value is the weight
  associated with a feature index.
dense_weights_by_index: a list of vectors where the value is the weight
  associated with the dense feature.
)doc");

}  // namespace tensorflow
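The Python layer later in this change drives the registered op through the generated `sdca_solver` wrapper. A stripped-down sketch of such a call, with toy tensors made up for illustration and the argument order copied from `SdcaModel.minimize()`, might look like:

```python
import tensorflow as tf
# Generated wrapper for the SdcaSolver op registered above; the wildcard
# import in python/ops/sdca_ops.py pulls in the same symbol.
from tensorflow.contrib.linear_optimizer.gen_sdca_ops import sdca_solver

# One sparse feature group, two examples, a three-entry vocabulary (toy data).
indices = tf.constant([[0, 1], [1, 2]], dtype=tf.int64)  # (example, feature)
values = tf.constant([1.0, 1.0])
weights = tf.Variable(tf.zeros([3]))      # sparse_weights_by_index[0]
dual = tf.Variable(tf.zeros([2]))         # one dual variable per example
primal_loss = tf.Variable(tf.zeros([], dtype=tf.float64))

train_op = sdca_solver(
    [indices], [values],                     # sparse feature indices/values
    [],                                      # no dense features
    tf.constant([1.0, 1.0]),                 # example_weights
    tf.constant([0.0, 1.0]),                 # example_labels
    tf.convert_to_tensor(dual, as_ref=True),
    [tf.convert_to_tensor(weights, as_ref=True)],
    [],                                      # dense_weights_by_index
    tf.convert_to_tensor(primal_loss, as_ref=True),
    L1=0.0, L2=1.0, LossType='logistic_loss')
```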
@@ -0,0 +1,269 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SdcaModel."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.python.framework.test_util import TensorFlowTestCase
from tensorflow.python.platform import googletest

# pylint: disable=invalid-name
SdcaModel = tf.contrib.linear_optimizer.SdcaModel


def make_example_proto(feature_dict, target):
  e = tf.train.Example()
  features = e.features

  features.feature['target'].float_list.value.append(target)

  for key, values in feature_dict.iteritems():
    features.feature[key + '_indices'].int64_list.value.extend(values)
    features.feature[key + '_values'].float_list.value.extend([1.0] *
                                                              len(values))

  return e


def make_example_dict(example_protos, example_weights):

  def parse_examples(example_protos):
    features = {
        'target': tf.FixedLenFeature(shape=[1],
                                     dtype=tf.float32,
                                     default_value=0),
        'age_indices': tf.VarLenFeature(dtype=tf.int64),
        'age_values': tf.VarLenFeature(dtype=tf.float32),
        'gender_indices': tf.VarLenFeature(dtype=tf.int64),
        'gender_values': tf.VarLenFeature(dtype=tf.float32)
    }
    return tf.parse_example(
        [e.SerializeToString() for e in example_protos], features)

  # TODO(rohananil): This converts two sparse tensors into one sparse feature
  # tensor. Use the tf core op once it is merged in.
  def sf_from_st(ids, weights):
    example_indices, _ = tf.split(1, 2, ids.indices)
    feature_indices = tf.expand_dims(ids.values, 1)
    indices = tf.concat(1, [example_indices, feature_indices])
    return tf.SparseTensor(indices, weights.values, ids.shape)

  parsed = parse_examples(example_protos)
  return dict(sparse_features=[
      sf_from_st(parsed['age_indices'], parsed['age_values']),
      sf_from_st(parsed['gender_indices'], parsed['gender_values'])
  ],
              dense_features=[],
              example_weights=example_weights,
              example_labels=tf.reshape(parsed['target'], [-1]))


def make_variable_dict(examples_dict, max_age, max_gender):
  # TODO(dbaylor): Figure out how to derive max_age & max_gender from
  # examples_dict.
  age_weights = tf.Variable(tf.zeros([max_age + 1], dtype=tf.float32))
  gender_weights = tf.Variable(tf.zeros([max_gender + 1], dtype=tf.float32))
  dual_variables = tf.Variable(tf.zeros_like(examples_dict['example_labels'],
                                             dtype=tf.float32))
  training_log_loss = tf.Variable(tf.zeros([], dtype=tf.float64))
  return dict(sparse_features_weights=[age_weights, gender_weights],
              dense_features_weights=[],
              dual=dual_variables,
              training_log_loss=training_log_loss)


class SdcaOptimizerTest(TensorFlowTestCase):

  def testSimple(self):
    # Set up test data.
    example_protos = [
        make_example_proto(
            {'age': [0],
             'gender': [0]}, 0),
        make_example_proto(
            {'age': [1],
             'gender': [1]}, 1),
    ]
    example_weights = [1.0, 1.0]
    with self.test_session(use_gpu=False):
      examples = make_example_dict(example_protos, example_weights)
      variables = make_variable_dict(examples, 1, 1)
      options = dict(symmetric_l2_regularization=0.5,
                     symmetric_l1_regularization=0,
                     loss_type='logistic_loss',
                     prior=0.0)
      tf.initialize_all_variables().run()
      lr = SdcaModel(examples, variables, options)
      unregularized_loss = lr.unregularized_loss(examples)
      loss = lr.regularized_loss(examples)
      prediction = lr.predictions(examples)
      self.assertAllClose(0.693147, unregularized_loss.eval())
      self.assertAllClose(0.693147, loss.eval())
      lr.minimize().run()
      self.assertAllClose(0.395226, unregularized_loss.eval())
      self.assertAllClose(0.657446, loss.eval())
      predicted_labels = tf.cast(
          tf.greater_equal(prediction,
                           tf.ones_like(prediction) * 0.5), tf.float32)
      self.assertAllEqual([0, 1], predicted_labels.eval())

  def testSomeUnweightedExamples(self):
    # Set up test data with 4 examples, but it should produce the same
    # results as testSimple.
    example_protos = [
        # Will be used.
        make_example_proto(
            {'age': [0],
             'gender': [0]}, 0),
        # Will be ignored.
        make_example_proto(
            {'age': [1],
             'gender': [0]}, 0),
        # Will be used.
        make_example_proto(
            {'age': [1],
             'gender': [1]}, 1),
        # Will be ignored.
        make_example_proto(
            {'age': [1],
             'gender': [0]}, 1),
    ]
    example_weights = [1.0, 0.0, 1.0, 0.0]
    with self.test_session(use_gpu=False):
      # Only use examples 0 and 2.
      examples = make_example_dict(example_protos, example_weights)
      variables = make_variable_dict(examples, 1, 1)
      options = dict(symmetric_l2_regularization=0.25,
                     symmetric_l1_regularization=0,
                     loss_type='logistic_loss')
      tf.initialize_all_variables().run()
      lr = SdcaModel(examples, variables, options)
      unregularized_loss = lr.unregularized_loss(examples)
      loss = lr.regularized_loss(examples)
      prediction = lr.predictions(examples)
      lr.minimize().run()
      self.assertAllClose(0.395226, unregularized_loss.eval())
      self.assertAllClose(0.526336, loss.eval())
      predicted_labels = tf.cast(
          tf.greater_equal(prediction,
                           tf.ones_like(prediction) * 0.5), tf.float32)
      self.assertAllClose([0, 1, 1, 1], predicted_labels.eval())

  def testNoWeightedExamples(self):
    # Set up test data with 1 positive and 1 negative example.
    example_protos = [
        make_example_proto(
            {'age': [0],
             'gender': [0]}, 0),
        make_example_proto(
            {'age': [1],
             'gender': [1]}, 1),
    ]
    # Zeroed out example weights.
    example_weights = [0.0, 0.0]
    with self.test_session(use_gpu=False):
      examples = make_example_dict(example_protos, example_weights)
      variables = make_variable_dict(examples, 1, 1)
      options = dict(symmetric_l2_regularization=0.5,
                     symmetric_l1_regularization=0,
                     loss_type='logistic_loss')
      tf.initialize_all_variables().run()
      lr = SdcaModel(examples, variables, options)
      self.assertAllClose([0.5, 0.5], lr.predictions(examples).eval())
      with self.assertRaisesOpError(
          'No weighted examples in 2 training examples'):
        lr.minimize().run()
      self.assertAllClose([0.5, 0.5], lr.predictions(examples).eval())

  def testImbalanced(self):
    # Set up test data with 1 positive and 3 negative examples.
    example_protos = [
        make_example_proto(
            {'age': [0],
             'gender': [0]}, 0),
        make_example_proto(
            {'age': [2],
             'gender': [0]}, 0),
        make_example_proto(
            {'age': [3],
             'gender': [0]}, 0),
        make_example_proto(
            {'age': [1],
             'gender': [1]}, 1),
    ]
    example_weights = [1.0, 1.0, 1.0, 1.0]
    with self.test_session(use_gpu=False):
      examples = make_example_dict(example_protos, example_weights)
      variables = make_variable_dict(examples, 3, 1)
      options = dict(symmetric_l2_regularization=0.25,
                     symmetric_l1_regularization=0,
                     loss_type='logistic_loss',
                     prior=-1.09861)
      tf.initialize_all_variables().run()
      lr = SdcaModel(examples, variables, options)
      unregularized_loss = lr.unregularized_loss(examples)
      loss = lr.regularized_loss(examples)
      prediction = lr.predictions(examples)
      lr.minimize().run()
      self.assertAllClose(0.331710,
                          unregularized_loss.eval(),
                          rtol=1e-2,
                          atol=1e-2)
      self.assertAllClose(0.591295, loss.eval(), rtol=1e-2, atol=1e-2)
      predicted_labels = tf.cast(
          tf.greater_equal(prediction,
                           tf.ones_like(prediction) * 0.5), tf.float32)
      self.assertAllEqual([0, 0, 0, 1], predicted_labels.eval())

  def testImbalancedWithExampleWeights(self):
    # Set up test data with 1 positive and 1 negative example, where the
    # negative example carries three times the weight.
    example_protos = [
        make_example_proto(
            {'age': [0],
             'gender': [0]}, 0),
        make_example_proto(
            {'age': [1],
             'gender': [1]}, 1),
    ]
    example_weights = [3.0, 1.0]
    with self.test_session(use_gpu=False):
      examples = make_example_dict(example_protos, example_weights)
      variables = make_variable_dict(examples, 1, 1)
      options = dict(symmetric_l2_regularization=0.25,
                     symmetric_l1_regularization=0,
                     loss_type='logistic_loss')
      tf.initialize_all_variables().run()
      lr = SdcaModel(examples, variables, options)
      unregularized_loss = lr.unregularized_loss(examples)
      loss = lr.regularized_loss(examples)
      prediction = lr.predictions(examples)
      lr.minimize().run()
      self.assertAllClose(0.266189,
                          unregularized_loss.eval(),
                          rtol=1e-2,
                          atol=1e-2)
      self.assertAllClose(0.571912, loss.eval(), rtol=1e-2, atol=1e-2)
      predicted_labels = tf.cast(
          tf.greater_equal(prediction,
                           tf.ones_like(prediction) * 0.5), tf.float32)
      self.assertAllEqual([0, 1], predicted_labels.eval())


if __name__ == '__main__':
  googletest.main()
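The 0.693147 expected before any training step in testSimple is just log 2: the weights start at zero, so every logit is zero and each unit-weight example contributes

```latex
\log\bigl(1 + e^{-y \cdot 0}\bigr) \;=\; \log 2 \;\approx\; 0.693147
```

to the weighted mean unregularized loss.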
275 tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py Normal file
@@ -0,0 +1,275 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Proximal stochastic dual coordinate ascent optimizer for linear models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.python.framework import ops
# pylint: disable=wildcard-import
from tensorflow.contrib.linear_optimizer.gen_sdca_ops import *


class SdcaModel(object):
  """Stochastic dual coordinate ascent solver for linear models.

  This class currently only supports a single machine (multi-threaded)
  implementation. We expect the data and weights to fit in a single machine.

  Loss functions supported:
  * Binary logistic loss

  This class defines an optimizer API to train a linear model.

  ### Usage

  ```python
  # Create a solver with the desired parameters.
  lr = tf.contrib.linear_optimizer.SdcaModel(examples, variables, options)
  opt_op = lr.minimize()

  predictions = lr.predictions(examples)
  # Primal loss + L1 loss + L2 loss.
  regularized_loss = lr.regularized_loss(examples)
  # Primal loss only.
  unregularized_loss = lr.unregularized_loss(examples)

  examples: {
    sparse_features: list of SparseTensors of value type float32.
    dense_features: list of dense tensors of type float32.
    example_labels: a tensor of shape [Num examples]
    example_weights: a tensor of shape [Num examples]
  }
  variables: {
    sparse_features_weights: list of tensors of shape [vocab size]
    dense_features_weights: list of tensors of shape [1]
    dual: tensor of shape [Num examples]
  }
  options: {
    symmetric_l1_regularization: 0.0
    symmetric_l2_regularization: 1.0
    loss_type: "logistic_loss"
  }
  ```

  In the training program you will just have to run the returned Op from
  minimize().

  ```python
  # Execute opt_op once to perform training, which continues until
  # convergence. The op makes use of the duality gap as a certificate for
  # termination; the duality gap threshold is set to 0.01 by default.
  opt_op.run()
  ```
  """

  def __init__(self, examples, variables, options):
    """Create a new sdca optimizer."""

    if not examples or not variables or not options:
      raise ValueError('All arguments must be specified.')

    if options['loss_type'] != 'logistic_loss':
      raise ValueError('Optimizer only supports logistic regression (for now).')

    self._assertSpecified(
        ['example_labels', 'example_weights', 'sparse_features',
         'dense_features'], examples)
    self._assertList(['sparse_features', 'dense_features'], examples)

    self._assertSpecified(
        ['sparse_features_weights', 'dense_features_weights',
         'training_log_loss', 'dual'], variables)
    self._assertList(
        ['sparse_features_weights', 'dense_features_weights'], variables)

    self._assertSpecified(
        ['loss_type', 'symmetric_l2_regularization',
         'symmetric_l1_regularization'], options)

    self._examples = examples
    self._variables = variables
    self._options = options
    self._training_log_loss = tf.convert_to_tensor(
        self._variables['training_log_loss'],
        as_ref=True)

  def _assertSpecified(self, items, check_in):
    for x in items:
      if check_in[x] is None:
        raise ValueError(x + ' must be specified.')

  def _assertList(self, items, check_in):
    for x in items:
      if not isinstance(check_in[x], list):
        raise ValueError(x + ' must be a list.')

  def _l1_loss(self):
    """Computes the l1 loss of the model."""
    with tf.name_scope('l1_loss'):
      sparse_weights = self._convert_n_to_tensor(self._variables[
          'sparse_features_weights'])
      dense_weights = self._convert_n_to_tensor(self._variables[
          'dense_features_weights'])
      l1 = self._options['symmetric_l1_regularization']
      loss = 0
      for w in sparse_weights:
        loss += l1 * tf.reduce_sum(tf.abs(w))
      for w in dense_weights:
        loss += l1 * tf.reduce_sum(tf.abs(w))
      return loss

  def _l2_loss(self):
    """Computes the l2 loss of the model."""
    with tf.name_scope('l2_loss'):
      sparse_weights = self._convert_n_to_tensor(self._variables[
          'sparse_features_weights'])
      dense_weights = self._convert_n_to_tensor(self._variables[
          'dense_features_weights'])
      l2 = self._options['symmetric_l2_regularization']
      loss = 0
      for w in sparse_weights:
        loss += l2 * tf.reduce_sum(tf.square(w))
      for w in dense_weights:
        loss += l2 * tf.reduce_sum(tf.square(w))
      return loss

  def _logits(self, examples):
    """Compute logits for each example."""
    with tf.name_scope('logits'):
      sparse_variables = self._convert_n_to_tensor(self._variables[
          'sparse_features_weights'])
      logits = 0
      for st_i, sv in zip(examples['sparse_features'], sparse_variables):
        ei, fi = tf.split(1, 2, st_i.indices)
        ei = tf.reshape(ei, [-1])
        fi = tf.reshape(fi, [-1])
        fv = tf.reshape(st_i.values, [-1])
        # TODO(rohananil): This does not work if examples have empty
        # features.
        logits += tf.segment_sum(
            tf.mul(
                tf.gather(sv, fi), fv), tf.reshape(ei, [-1]))
      dense_features = self._convert_n_to_tensor(examples['dense_features'])
      dense_variables = self._convert_n_to_tensor(self._variables[
          'dense_features_weights'])
      for i in xrange(len(dense_variables)):
        logits += dense_features[i] * dense_variables[i]
      return logits

  def _convert_n_to_tensor(self, input_list, as_ref=False):
    """Converts input list to a set of tensors."""
    return [tf.convert_to_tensor(x, as_ref=as_ref) for x in input_list]

  def predictions(self, examples):
    """Add operations to compute predictions by the model.

    Args:
      examples: Examples to compute predictions on.

    Returns:
      An Operation that computes the predictions for examples. For logistic
      loss the output is a tensor with the sigmoid output.
    Raises:
      ValueError: if examples are not well defined.
    """
    self._assertSpecified(
        ['example_weights', 'sparse_features', 'dense_features'], examples)
    self._assertList(['sparse_features', 'dense_features'], examples)
    with tf.name_scope('sdca/prediction'):
      logits = self._logits(examples)
      # TODO(rohananil): Change the prediction when supporting linear
      # regression.
      return tf.sigmoid(logits)

  def minimize(self):
    """Add operations to train a linear model by minimizing the loss function.

    Returns:
      An Operation that updates the variables passed in the constructor.
    """
    with tf.name_scope('sdca/minimize'):
      sparse_features_indices = []
      sparse_features_weights = []
      for sf in self._examples['sparse_features']:
        sparse_features_indices.append(ops.convert_to_tensor(sf.indices))
        sparse_features_weights.append(ops.convert_to_tensor(sf.values))

      return sdca_solver(
          sparse_features_indices,
          sparse_features_weights,
          self._convert_n_to_tensor(self._examples['dense_features']),
          tf.convert_to_tensor(self._examples['example_weights']),
          tf.convert_to_tensor(self._examples['example_labels']),
          tf.convert_to_tensor(self._variables['dual'], as_ref=True),
          self._convert_n_to_tensor(self._variables['sparse_features_weights'],
                                    as_ref=True),
          self._convert_n_to_tensor(self._variables['dense_features_weights'],
                                    as_ref=True),
          self._training_log_loss,
          L1=self._options['symmetric_l1_regularization'],
          L2=self._options['symmetric_l2_regularization'],
          LossType=self._options['loss_type'])

  def unregularized_loss(self, examples):
    """Add operations to compute the loss (without the regularization loss).

    Args:
      examples: Examples to compute unregularized loss on.

    Returns:
      An Operation that computes the mean (unregularized) loss for the given
      set of examples.
    Raises:
      ValueError: if examples are not well defined.
    """
    self._assertSpecified(
        ['example_labels', 'example_weights', 'sparse_features',
         'dense_features'], examples)
    self._assertList(['sparse_features', 'dense_features'], examples)
    with tf.name_scope('sdca/unregularized_loss'):
      logits = self._logits(examples)
      # TODO(rohananil): Change the loss when supporting linear regression.
      return tf.reduce_sum(tf.mul(
          tf.nn.sigmoid_cross_entropy_with_logits(
              logits, tf.convert_to_tensor(examples['example_labels'])),
          tf.convert_to_tensor(examples['example_weights']))) / tf.reduce_sum(
              ops.convert_to_tensor(examples['example_weights']))

  def regularized_loss(self, examples):
    """Add operations to compute the loss with regularization loss included.

    Args:
      examples: Examples to compute loss on.

    Returns:
      An Operation that computes the mean (regularized) loss for the given
      set of examples.
    Raises:
      ValueError: if examples are not well defined.
    """
    self._assertSpecified(
        ['example_labels', 'example_weights', 'sparse_features',
         'dense_features'], examples)
    self._assertList(['sparse_features', 'dense_features'], examples)
    with tf.name_scope('sdca/regularized_loss'):
      logits = self._logits(examples)
      # TODO(rohananil): Change the loss when supporting linear regression.
      return self._l1_loss() + self._l2_loss() + self.unregularized_loss(
          examples)
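Putting the pieces together, a minimal end-to-end run of the new optimizer (a sketch modeled on the unit test earlier in this change; the feature layout and constants are illustrative assumptions, not part of the patch) would be:

```python
import tensorflow as tf

# Two examples, one sparse feature group with a two-entry vocabulary.
examples = dict(
    sparse_features=[tf.SparseTensor(
        indices=tf.constant([[0, 0], [1, 1]], dtype=tf.int64),  # (example, feature)
        values=[1.0, 1.0],
        shape=[2, 2])],
    dense_features=[],
    example_weights=[1.0, 1.0],
    example_labels=tf.constant([0.0, 1.0]))

variables = dict(
    sparse_features_weights=[tf.Variable(tf.zeros([2]))],
    dense_features_weights=[],
    dual=tf.Variable(tf.zeros([2])),
    training_log_loss=tf.Variable(tf.zeros([], dtype=tf.float64)))

options = dict(symmetric_l1_regularization=0,
               symmetric_l2_regularization=1.0,
               loss_type='logistic_loss')

with tf.Session():
  tf.initialize_all_variables().run()
  model = tf.contrib.linear_optimizer.SdcaModel(examples, variables, options)
  model.minimize().run()  # trains until the duality-gap stopping test fires
  print(model.regularized_loss(examples).eval())
  print(model.predictions(examples).eval())
```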
@ -408,7 +408,7 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
|
||||
// Check that this new set of fetches can be computed from all the
|
||||
// feeds we have supplied.
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckFetch(inputs, output_names, executors_and_keys->graph, run_state));
|
||||
CheckFetch(inputs, output_names, executors_and_keys, run_state));
|
||||
|
||||
// Send inputs.
|
||||
Status s = SendInputs(inputs, executors_and_keys, run_state->rendez);
|
||||
@ -510,12 +510,10 @@ Status DirectSession::RecvOutputs(const std::vector<string>& output_names,
|
||||
|
||||
Status DirectSession::CheckFetch(const NamedTensorList& feeds,
|
||||
const std::vector<string>& fetches,
|
||||
const Graph* graph,
|
||||
const ExecutorsAndKeys* executors_and_keys,
|
||||
const RunState* run_state) {
|
||||
std::unordered_map<StringPiece, Node*, StringPiece::Hasher> name_to_node;
|
||||
for (Node* n : graph->nodes()) {
|
||||
name_to_node[n->name()] = n;
|
||||
}
|
||||
const Graph* graph = executors_and_keys->graph;
|
||||
const NameNodeMap* name_to_node = executors_and_keys->name_to_node;
|
||||
|
||||
// Build the set of pending feeds that we haven't seen.
|
||||
std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
|
||||
@ -523,8 +521,8 @@ Status DirectSession::CheckFetch(const NamedTensorList& feeds,
|
||||
mutex_lock l(executor_lock_);
|
||||
for (const string& feed : run_state->pending_inputs) {
|
||||
TensorId id(ParseTensorName(feed));
|
||||
auto it = name_to_node.find(id.first);
|
||||
if (it == name_to_node.end()) {
|
||||
auto it = name_to_node->find(id.first);
|
||||
if (it == name_to_node->end()) {
|
||||
return errors::NotFound("Feed ", feed, ": not found");
|
||||
}
|
||||
pending_feeds.insert(id);
|
||||
@ -539,8 +537,8 @@ Status DirectSession::CheckFetch(const NamedTensorList& feeds,
|
||||
std::vector<const Node*> stack;
|
||||
for (const string& fetch : fetches) {
|
||||
TensorId id(ParseTensorName(fetch));
|
||||
auto it = name_to_node.find(id.first);
|
||||
if (it == name_to_node.end()) {
|
||||
auto it = name_to_node->find(id.first);
|
||||
if (it == name_to_node->end()) {
|
||||
return errors::NotFound("Fetch ", fetch, ": not found");
|
||||
}
|
||||
stack.push_back(it->second);
|
||||
@ -618,6 +616,10 @@ Status DirectSession::GetOrCreateExecutors(
|
||||
ek->func_defs = fdefs;
|
||||
if (run_state_args->is_partial_run) {
|
||||
ek->graph = run_state_args->graph;
|
||||
ek->name_to_node = new NameNodeMap;
|
||||
for (Node* n : run_state_args->graph->nodes()) {
|
||||
ek->name_to_node->insert({n->name(), n});
|
||||
}
|
||||
}
|
||||
ek->items.reserve(graphs.size());
|
||||
auto runner = [this](Executor::Args::Closure c) { SchedClosure(c); };
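The net effect of this change is to build the node-name lookup table once, when the executors for a partial run are created, and reuse it on every subsequent CheckFetch call instead of rebuilding it from the graph each time. A rough Python sketch of that caching pattern; the class and function names mirror the C++ but are purely illustrative, not the DirectSession API:

```python
import collections

Node = collections.namedtuple('Node', ['name'])


class ExecutorsAndKeys(object):
  """Per-(feeds, fetches) state; the lookup map is kept only for partial runs."""

  def __init__(self, nodes):
    # Built once when the executors are created, reused by every check.
    self.name_to_node = {node.name: node for node in nodes}


def check_fetch(fetches, executors_and_keys):
  """Validates fetch names against the cached name-to-node map."""
  for fetch in fetches:
    node_name = fetch.split(':')[0]  # 'op_name:0' -> 'op_name'
    if node_name not in executors_and_keys.name_to_node:
      raise KeyError('Fetch %s: not found' % fetch)


ek = ExecutorsAndKeys([Node('input'), Node('logits')])
check_fetch(['logits:0'], ek)  # passes without rebuilding the map
```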
@ -49,6 +49,8 @@ class DirectSession : public Session {
  ~DirectSession() override;

  typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
  typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher>
      NameNodeMap;

  ::tensorflow::Status Create(const GraphDef& graph) override;
  ::tensorflow::Status Extend(const GraphDef& graph) override;
@ -81,13 +83,15 @@ class DirectSession : public Session {
  // An ExecutorsAndKeys is created for a given set of feeds/fetches.
  // 'func_defs' are the function definition used by all the
  // underlying executors. 'graph' is the entire graph being
  // executed. Each item in 'items' is the executor for a
  // partition of the graph bundled with its dependent library
  // runtime. 'input_keys' are the rendezvous keys for the feeds and
  // 'output_keys' are rendezvous keys for the fetches.
  // executed. 'name_to_node' maps node name to node. We keep 'graph'
  // and 'name_to_node' only in the case of partial runs. Each item in
  // 'items' is the executor for a partition of the graph bundled with
  // its dependent library runtime. 'input_keys' are the rendezvous keys
  // for the feeds and 'output_keys' are rendezvous keys for the fetches.
  struct ExecutorsAndKeys {
    FunctionLibraryDefinition* func_defs = nullptr;
    Graph* graph = nullptr;
    NameNodeMap* name_to_node = nullptr;
    std::vector<PerPartitionExecutorsAndLib> items;
    std::unordered_map<string, string> input_keys;
    std::unordered_map<string, string> output_keys;
@ -99,6 +103,7 @@ class DirectSession : public Session {
      }
      delete func_defs;
      delete graph;
      delete name_to_node;
    }
  };

@ -177,8 +182,8 @@ class DirectSession : public Session {
  // that we have already provided.
  ::tensorflow::Status CheckFetch(
      const std::vector<std::pair<string, Tensor>>& feeds,
      const std::vector<string>& fetches, const Graph* graph,
      const RunState* run_state);
      const std::vector<string>& fetches,
      const ExecutorsAndKeys* executors_and_keys, const RunState* run_state);

  const SessionOptions options_;

@ -82,7 +82,8 @@ def main(_):

  # Merge all the summaries and write them out to /tmp/mnist_logs (by default)
  merged = tf.merge_all_summaries()
  writer = tf.train.SummaryWriter(FLAGS.summaries_dir, sess.graph_def)
  writer = tf.train.SummaryWriter(FLAGS.summaries_dir,
                                  sess.graph.as_graph_def(add_shapes=True))
  tf.initialize_all_variables().run()

  # Train the model, and feed in test data and record summaries every 10 steps

@ -77,6 +77,11 @@ $ sudo easy_install --upgrade six
$ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.7.0-py3-none-any.whl
```

NOTE: If you are upgrading from a previous installation of TensorFlow < 0.7.1,
you should uninstall the previous TensorFlow *and protobuf* using `pip
uninstall` first to make sure you get a clean installation of the updated
protobuf dependency.


You can now [test your installation](#test-the-tensorflow-installation).

@ -582,6 +587,19 @@ explicitly.

### Pip installation issues

#### Cannot import name 'descriptor'

```python
ImportError: Traceback (most recent call last):
  File "/usr/local/lib/python3.4/dist-packages/tensorflow/core/framework/graph_pb2.py", line 6, in <module>
    from google.protobuf import descriptor as _descriptor
ImportError: cannot import name 'descriptor'
```

If you see the above error when upgrading to a newer version of TensorFlow, try
uninstalling both TensorFlow and protobuf (if installed) and re-installing
TensorFlow (which will also install the correct protobuf dependency).

#### Can't find setup.py

If, during `pip install`, you encounter an error like:

@ -228,3 +228,26 @@ The images below give an illustration for a piece of a real-life graph.
    </td>
  </tr>
</table>

## Tensor shape information

When the serialized `GraphDef` includes tensor shapes, the graph visualizer
labels edges with tensor dimensions, and edge thickness reflects total tensor
size. To include tensor shapes in the `GraphDef`, pass
`sess.graph.as_graph_def(add_shapes=True)` to the `SummaryWriter` when
serializing the graph. The images below show the CIFAR-10 model with tensor
shape information:
<table width="100%;">
  <tr>
    <td style="width: 100%;">
      <img src="../../images/tensor_shapes.png" alt="CIFAR-10 model with tensor shape information" title="CIFAR-10 model with tensor shape information" />
    </td>
  </tr>
  <tr>
    <td style="width: 100%;">
      CIFAR-10 model with tensor shape information.
    </td>
  </tr>
</table>
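For a quick end-to-end reference, here is a minimal sketch of serializing a graph with shape annotations using the 0.7-era API exercised by this change; the logdir path and the toy graph are placeholders, not taken from the change itself:

```python
import tensorflow as tf

# A tiny graph so the serialized GraphDef has shapes worth annotating.
x = tf.placeholder(tf.float32, [None, 784], name='input')
w = tf.Variable(tf.zeros([784, 10]), name='weights')
y = tf.matmul(x, w, name='logits')

with tf.Session() as sess:
  # add_shapes=True attaches inferred output shapes to each node, so the
  # graph visualizer can label edges with tensor dimensions.
  writer = tf.train.SummaryWriter('/tmp/shape_demo_logs',
                                  sess.graph.as_graph_def(add_shapes=True))
  writer.flush()
```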

@ -57,6 +57,10 @@ The `SummaryWriter` takes a logdir in its constructor - this logdir is quite
important, it's the directory where all of the events will be written out.
Also, the `SummaryWriter` can optionally take a `GraphDef` in its constructor.
If it receives one, then TensorBoard will visualize your graph as well.
To include tensor shape information in the `GraphDef`, pass
`sess.graph.as_graph_def(add_shapes=True)` to the `SummaryWriter`. This will
give you a much better sense of what flows through the graph: see
[Tensor shape information](../../how_tos/graph_viz/index.md#tensor-shape-information).

Now that you've modified your graph and have a `SummaryWriter`, you're ready to
start running your network! If you want, you could run the merged summary op
@ -102,7 +106,8 @@ with tf.name_scope("test") as scope:

# Merge all the summaries and write them out to /tmp/mnist_logs
merged = tf.merge_all_summaries()
writer = tf.train.SummaryWriter("/tmp/mnist_logs", sess.graph_def)
writer = tf.train.SummaryWriter("/tmp/mnist_logs",
                                sess.graph.as_graph_def(add_shapes=True))
tf.initialize_all_variables().run()

# Train the model, and feed in test data and record summaries every 10 steps

@ -136,7 +136,7 @@ def evaluate():
    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    graph_def = tf.get_default_graph().as_graph_def()
    graph_def = tf.get_default_graph().as_graph_def(add_shapes=True)
    summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir,
                                            graph_def=graph_def)


@ -239,8 +239,9 @@ def train():
    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    graph_def = sess.graph.as_graph_def(add_shapes=True)
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                            graph_def=sess.graph_def)
                                            graph_def=graph_def)

    for step in xrange(FLAGS.max_steps):
      start_time = time.time()

@ -93,8 +93,9 @@ def train():
    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    graph_def = sess.graph.as_graph_def(add_shapes=True)
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                            graph_def=sess.graph_def)
                                            graph_def=graph_def)

    for step in xrange(FLAGS.max_steps):
      start_time = time.time()