Fix bug where attrs with values that are the empty list
were not being properly set via the Python API. Change: 111635679
parent d38fecedf5
commit 02dff6d0d8
@@ -11,6 +11,14 @@
 safety is handled by `saturate_cast`, which makes sure over- and underflows
 are handled before casting to data types with smaller ranges.

+## Bug fixes
+
+* The Python API will now properly set the `list` member of `AttrValue` in
+  constructed `GraphDef` messages for empty lists. The serialization of some
+  graphs will change, but the change is both forwards and backwards compatible.
+  It will break tests that compare a generated `GraphDef` to a golden serialized
+  `GraphDef`.
+
 # Release 0.6.0

 ## Major Features and Improvements
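As context for the bug-fix note above, here is a minimal sketch (illustrative only, not part of this commit) of what "properly set the `list` member" means at the proto level. The generated proto module path and the attr name `values` are assumptions:

```python
from tensorflow.core.framework import node_def_pb2

node = node_def_pb2.NodeDef(name="x", op="ExampleOp")

# Previously, an attr built from an empty Python list left `list` unset, so the
# attr serialized as an empty AttrValue. After this fix the (empty) list is
# explicitly marked present, which changes the bytes but not the meaning.
node.attr["values"].list.SetInParent()

print(node.attr["values"].HasField("list"))  # True
print(node.SerializeToString())              # the attr entry now carries an explicit empty list
```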
@@ -121,6 +121,19 @@ class OpKernel {
   Status InputRange(const string& input_name, int* start, int* stop) const;
   Status OutputRange(const string& output_name, int* start, int* stop) const;

+  // TODO(irving): At the moment, the following three functions forward to
+  // TensorShapeUtils, but they are about to become the only versions once we
+  // become scalar strict.
+  bool allow_legacy_scalars() const { return kAllowLegacyScalars; }
+
+  bool IsLegacyScalar(const TensorShape& shape) const {
+    return TensorShapeUtils::IsLegacyScalar(shape);
+  }
+
+  bool IsLegacyVector(const TensorShape& shape) const {
+    return TensorShapeUtils::IsLegacyVector(shape);
+  }
+
  private:
   const NodeDef def_;
   const DataTypeVector input_types_;
@@ -455,6 +468,8 @@ class OpKernelContext {

   Env* env() const { return params_.device->env(); }

+  const OpKernel& op_kernel() const { return *params_.op_kernel; }
+
   // Input/output signature.

   int num_inputs() const { return params_.inputs->size(); }
@@ -45,7 +45,7 @@ class ConcatOp : public OpKernel {
     const Tensor* concat_dim_tensor;
     OP_REQUIRES_OK(c, c->input("concat_dim", &concat_dim_tensor));
     OP_REQUIRES(
-        c, TensorShapeUtils::IsLegacyScalar(concat_dim_tensor->shape()),
+        c, IsLegacyScalar(concat_dim_tensor->shape()),
         errors::InvalidArgument(
             "Concat dim tensor should be a scalar integer, but got shape ",
             concat_dim_tensor->shape().DebugString()));
@@ -57,7 +57,7 @@ class ConcatOp : public OpKernel {
     const TensorShape& input_shape = values[0].shape();
     OP_REQUIRES(
         c, (0 <= concat_dim && concat_dim < input_dims) ||
-               (kAllowLegacyScalars && concat_dim == 0),
+               (allow_legacy_scalars() && concat_dim == 0),
         errors::InvalidArgument(
             "ConcatOp : Expected concatenating dimensions in the range [", 0,
             ", ", input_dims, "), but got ", concat_dim));
@@ -74,10 +74,10 @@ class ConcatOp : public OpKernel {
       inputs_flat_dim0 *= input_shape.dim_size(d);
     }
     int output_concat_dim = 0;
-    const bool input_is_scalar = TensorShapeUtils::IsLegacyScalar(input_shape);
+    const bool input_is_scalar = IsLegacyScalar(input_shape);
     for (int i = 0; i < N; ++i) {
       const auto in = values[i];
-      const bool in_is_scalar = TensorShapeUtils::IsLegacyScalar(in.shape());
+      const bool in_is_scalar = IsLegacyScalar(in.shape());
       OP_REQUIRES(
           c, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
           errors::InvalidArgument(
@@ -100,12 +100,12 @@ class ConcatOp : public OpKernel {
         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
             in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
       }
-      // TODO(irving): Remove check once !kAllowLegacyScalars
+      // TODO(irving): Remove check once !allow_legacy_scalars().
       output_concat_dim += in.dims() > 0 ? in.dim_size(concat_dim) : 1;
     }

     TensorShape output_shape(input_shape);
-    // TODO(irving): Remove rank 0 case once !kAllowLegacyScalars
+    // TODO(irving): Remove rank 0 case once !allow_legacy_scalars().
     if (output_shape.dims() == 0) {
       output_shape.AddDim(output_concat_dim);
     } else {
@@ -143,11 +143,14 @@ class FillOp : public OpKernel {

   void Compute(OpKernelContext* context) override {
     const Tensor& Tdims = context->input(0);
-    OP_REQUIRES(context, TensorShapeUtils::IsLegacyVector(Tdims.shape()),
-                errors::InvalidArgument("dims must be a vector of int32."));
+    OP_REQUIRES(
+        context, IsLegacyVector(Tdims.shape()),
+        errors::InvalidArgument("dims must be a vector of int32, got shape ",
+                                Tdims.shape().ShortDebugString()));
     const Tensor& Tvalue = context->input(1);
-    OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(Tvalue.shape()),
-                errors::InvalidArgument("value must be a scalar."));
+    OP_REQUIRES(context, IsLegacyScalar(Tvalue.shape()),
+                errors::InvalidArgument("value must be a scalar, got shape ",
+                                        Tvalue.shape().ShortDebugString()));
     auto dims = Tdims.flat<int32>();
     for (int i = 0; i < dims.size(); i++) {
       OP_REQUIRES(context, dims(i) >= 0,
@@ -28,7 +28,7 @@ class AssertOp : public OpKernel {

   void Compute(OpKernelContext* ctx) override {
     const Tensor& cond = ctx->input(0);
-    OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(cond.shape()),
+    OP_REQUIRES(ctx, IsLegacyScalar(cond.shape()),
                 errors::InvalidArgument("In[0] should be a scalar: ",
                                         cond.shape().ShortDebugString()));

@@ -59,7 +59,8 @@ class PadOp : public OpKernel {
         errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
                                 in1.shape().DebugString()));
     const int fixed_dims =
-        (kAllowLegacyScalars && dims == 0 && in1.dim_size(0) == 1) ? 1 : dims;
+        (allow_legacy_scalars() && dims == 0 && in1.dim_size(0) == 1) ? 1
+                                                                      : dims;
     OP_REQUIRES(
         context, fixed_dims == in1.dim_size(0),
         errors::InvalidArgument(
@@ -76,7 +77,7 @@ class PadOp : public OpKernel {
           errors::InvalidArgument("Paddings must be non-negative: ",
                                   before_d, " ", after_d));
       const int size_d =
-          (kAllowLegacyScalars && d == in0.dims()) ? 1 : in0.dim_size(d);
+          (allow_legacy_scalars() && d == in0.dims()) ? 1 : in0.dim_size(d);
       output_shape.AddDim(before_d + size_d + after_d);
     }
     Tensor* output = nullptr;
@@ -89,7 +90,7 @@ class PadOp : public OpKernel {
         break;
       case 1:
         // TODO(irving): Once Pad doesn't need a scalar special case,
-        // change flat to tensor. That is, once !kAllowLegacyScalars.
+        // change flat to tensor. That is, once !allow_legacy_scalars().
         Operate<1>(context, in0.flat<T>(), paddings, output);
         break;
       case 2:
@@ -180,7 +180,7 @@ namespace {

 static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape,
                                       int index, Tensor** output) {
-  if (!TensorShapeUtils::IsLegacyVector(shape.shape())) {
+  if (!ctx->op_kernel().IsLegacyVector(shape.shape())) {
     return errors::InvalidArgument(
         "shape must be a vector of {int32,int64}, got shape ",
         shape.shape().ShortDebugString());
@@ -35,7 +35,7 @@ class ReshapeOp : public OpKernel {
     const Tensor& input = context->input(0);
     const Tensor& sizes = context->input(1);
     // Preliminary validation of sizes.
-    OP_REQUIRES(context, TensorShapeUtils::IsLegacyVector(sizes.shape()),
+    OP_REQUIRES(context, IsLegacyVector(sizes.shape()),
                 errors::InvalidArgument("sizes input must be 1-D, not shape ",
                                         sizes.shape().ShortDebugString()));
     const int64 num_dims = sizes.NumElements();
@@ -55,7 +55,7 @@ class ShardedFilenameOp : public OpKernel {
   void Compute(OpKernelContext* ctx) override {
     static const char* input_names[3] = {"basename", "shard", "num_shards"};
     for (int i = 0; i < ctx->num_inputs(); ++i) {
-      OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(ctx->input(i).shape()),
+      OP_REQUIRES(ctx, IsLegacyScalar(ctx->input(i).shape()),
                   errors::InvalidArgument(
                       input_names[i], " must be a scalar, got shape ",
                       ctx->input(i).shape().ShortDebugString()));
@@ -78,7 +78,7 @@ class ShardedFilespecOp : public OpKernel {
   void Compute(OpKernelContext* ctx) override {
     static const char* input_names[2] = {"basename", "num_shards"};
     for (int i = 0; i < ctx->num_inputs(); ++i) {
-      OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(ctx->input(i).shape()),
+      OP_REQUIRES(ctx, IsLegacyScalar(ctx->input(i).shape()),
                   errors::InvalidArgument(
                       input_names[i], " must be a scalar, got shape ",
                       ctx->input(i).shape().ShortDebugString()));
@@ -184,7 +184,7 @@ class UnsortedSegmentSumOp : public OpKernel {
     const Tensor& num_segments = context->input(2);

     OP_REQUIRES(
-        context, TensorShapeUtils::IsLegacyScalar(num_segments.shape()),
+        context, IsLegacyScalar(num_segments.shape()),
         errors::InvalidArgument("num_segments should be a scalar, not shape ",
                                 num_segments.shape().ShortDebugString()));

@@ -406,7 +406,7 @@ class SparseSegmentMeanGradOp : public OpKernel {
                 errors::InvalidArgument("indices should be a vector."));
     OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
                 errors::InvalidArgument("segment_ids should be a vector."));
-    OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(output_dim0.shape()),
+    OP_REQUIRES(context, IsLegacyScalar(output_dim0.shape()),
                 errors::InvalidArgument("output_dim0 should be a scalar."));

     const int64 N = indices.NumElements();
@@ -34,13 +34,13 @@ class RangeOp : public OpKernel {
     const Tensor& start_in = context->input(0);
     const Tensor& limit_in = context->input(1);
    const Tensor& delta_in = context->input(2);
-    OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(start_in.shape()),
+    OP_REQUIRES(context, IsLegacyScalar(start_in.shape()),
                 errors::InvalidArgument("start must be a scalar, not shape ",
                                         start_in.shape().ShortDebugString()));
-    OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(limit_in.shape()),
+    OP_REQUIRES(context, IsLegacyScalar(limit_in.shape()),
                 errors::InvalidArgument("limit must be a scalar, not shape ",
                                         limit_in.shape().ShortDebugString()));
-    OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(delta_in.shape()),
+    OP_REQUIRES(context, IsLegacyScalar(delta_in.shape()),
                 errors::InvalidArgument("delta must be a scalar, not shape ",
                                         delta_in.shape().ShortDebugString()));
     const int32 start = GetValue(start_in.scalar<T>()());
@@ -69,14 +69,15 @@ static void SharedValidation(OpKernelContext* context,
   const Tensor& size_tensor = context->input(2);

   OP_REQUIRES(
-      context, TensorShapeUtils::IsLegacyVector(begin_tensor.shape()) &&
-                   TensorShapeUtils::IsLegacyVector(size_tensor.shape()) &&
+      context, context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
+                   context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
                    begin_tensor.NumElements() == input.dims() &&
                    size_tensor.NumElements() == input.dims(),
       errors::InvalidArgument(
           "Expected begin and size arguments to be 1-D tensors of size ",
-          input.dims(), ", but got ", begin_tensor.NumElements(), " and ",
-          size_tensor.NumElements(), " instead."));
+          input.dims(), ", but got shapes ",
+          begin_tensor.shape().ShortDebugString(), " and ",
+          size_tensor.shape().ShortDebugString(), " instead."));

   const int input_dims = input.dims();
   *begin = IntTensorToInt64Vec(begin_tensor);
@@ -60,7 +60,7 @@ class SparseToDense : public OpKernel {
     // output_shape
     const Tensor& output_shape = c->input(1);
     OP_REQUIRES(
-        c, TensorShapeUtils::IsLegacyVector(output_shape.shape()),
+        c, IsLegacyVector(output_shape.shape()),
         errors::InvalidArgument("output_shape should be a vector, got shape ",
                                 output_shape.shape().ShortDebugString()));
     OP_REQUIRES(c, output_shape.NumElements() == num_dims,
@@ -48,8 +48,8 @@ class SummaryImageOp : public OpKernel {
   void Compute(OpKernelContext* c) override {
     const Tensor& tags = c->input(0);
     const Tensor& tensor = c->input(1);
-    OP_REQUIRES(c, TensorShapeUtils::IsLegacyScalar(tags.shape()),
-                errors::InvalidArgument("Tags must have be a scalar"));
+    OP_REQUIRES(c, IsLegacyScalar(tags.shape()),
+                errors::InvalidArgument("Tags must be a scalar"));
     OP_REQUIRES(c, tensor.dims() == 4 &&
                    (tensor.dim_size(3) == 1 || tensor.dim_size(3) == 3 ||
                     tensor.dim_size(3) == 4),
@@ -40,12 +40,12 @@ class SummaryScalarOp : public OpKernel {
     const Tensor& tags = c->input(0);
     const Tensor& values = c->input(1);

-    OP_REQUIRES(c, tags.IsSameSize(values) ||
-                   (TensorShapeUtils::IsLegacyScalar(tags.shape()) &&
-                    TensorShapeUtils::IsLegacyScalar(values.shape())),
+    OP_REQUIRES(c, tags.IsSameSize(values) || (IsLegacyScalar(tags.shape()) &&
+                                               IsLegacyScalar(values.shape())),
                 errors::InvalidArgument("tags and values not the same shape: ",
                                         tags.shape().ShortDebugString(), " != ",
-                                        values.shape().ShortDebugString()));
+                                        values.shape().ShortDebugString(),
+                                        SingleTag(tags)));
     auto Ttags = tags.flat<string>();
     auto Tvalues = values.flat<T>();
     Summary s;
@@ -59,6 +59,15 @@ class SummaryScalarOp : public OpKernel {
     OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
     CHECK(s.SerializeToString(&summary_tensor->scalar<string>()()));
   }
+
+  // If there's only one tag, include it in the error message
+  static string SingleTag(const Tensor& tags) {
+    if (tags.NumElements() == 1) {
+      return strings::StrCat(" (tag '", tags.flat<string>()(0), "')");
+    } else {
+      return "";
+    }
+  }
 };

 template <typename T>
@@ -72,7 +81,7 @@ class SummaryHistoOp : public OpKernel {
     const Tensor& tags = c->input(0);
     const Tensor& values = c->input(1);
     const auto flat = values.flat<T>();
-    OP_REQUIRES(c, TensorShapeUtils::IsLegacyScalar(tags.shape()),
+    OP_REQUIRES(c, IsLegacyScalar(tags.shape()),
                 errors::InvalidArgument("tags must be scalar"));
     // Build histogram of values in "values" tensor
     histogram::Histogram histo;
@@ -46,7 +46,7 @@ class TileOp : public OpKernel {
     const Tensor& multiples = context->input(1);

     OP_REQUIRES(
-        context, TensorShapeUtils::IsLegacyVector(multiples.shape()),
+        context, IsLegacyVector(multiples.shape()),
         errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
                                 multiples.shape().ShortDebugString()));
     OP_REQUIRES(context, input.dims() == multiples.NumElements(),
@@ -192,7 +192,7 @@ class TileGradientOp : public OpKernel {
     const Tensor& input = context->input(0);
     const Tensor& multiples = context->input(1);
     OP_REQUIRES(
-        context, TensorShapeUtils::IsLegacyVector(multiples.shape()),
+        context, IsLegacyVector(multiples.shape()),
         errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
                                 multiples.shape().ShortDebugString()));
     OP_REQUIRES(context, input.dims() == multiples.NumElements(),
@@ -153,7 +153,7 @@ class ApplyGradientDescentOp : public OpKernel {
                 errors::FailedPrecondition(
                     "Attempting to use uninitialized variables: ", def().input(0)));
     const Tensor& alpha = ctx->input(1);
-    OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(alpha.shape()),
+    OP_REQUIRES(ctx, IsLegacyScalar(alpha.shape()),
                 errors::InvalidArgument("alpha is not a scalar: ",
                                         alpha.shape().DebugString()));
     const Tensor& delta = ctx->input(2);
@@ -242,7 +242,7 @@ class ApplyAdagradOp : public OpKernel {
                 errors::FailedPrecondition(
                     "Attempting to use uninitialized variables: ", def().input(1)));
     const Tensor& lr = ctx->input(2);
-    OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(lr.shape()),
+    OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
                 errors::InvalidArgument("lr is not a scalar: ",
                                         lr.shape().DebugString()));
     const Tensor& grad = ctx->input(3);
@@ -336,7 +336,7 @@ class SparseApplyAdagradOp : public OpKernel {
                 errors::InvalidArgument("var must be at least 1 dimensional"));

     const Tensor& lr = ctx->input(2);
-    OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(lr.shape()),
+    OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
                 errors::InvalidArgument("lr is not a scalar: ",
                                         lr.shape().DebugString()));
     const Tensor& grad = ctx->input(3);
@@ -683,6 +683,7 @@ py_library(
         "ops/image_ops.py",
         "ops/init_ops.py",
         "ops/io_ops.py",
+        "ops/learn.py",
         "ops/linalg_grad.py",
         "ops/linalg_ops.py",
         "ops/logging_ops.py",
@@ -57,7 +57,8 @@ from tensorflow.python.client.client_lib import *
 # Ops
 from tensorflow.python.ops.standard_ops import *

-# Bring nn, image_ops, user_ops, compat as a subpackages
+# Bring learn, nn, image_ops, user_ops, compat as a subpackages
+from tensorflow.python.ops import learn
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import image_ops as image
 from tensorflow.python.user_ops import user_ops
@@ -77,7 +78,7 @@ from tensorflow.python.platform import resource_loader
 from tensorflow.python.platform import test

 # Don't export modules except for the few we really want
-_whitelist = set([app, compat, errors, flags, image, logging, nn,
+_whitelist = set([app, compat, errors, flags, image, learn, logging, nn,
                   python_io, resource_loader, test, train, user_ops])
 # TODO(b/25561952): tf.tensor_util is DEPRECATED. Please avoid.
 _whitelist.update([tensor_util])  # pylint: disable=undefined-variable
@@ -3159,6 +3159,8 @@ class GraphKeys(object):
     keep moving averages. See
     [`tf.moving_average_variables()`](../../api_docs/python/state_ops.md#moving_average_variables)
     for more details.
+  * `REGULARIZATION_LOSSES`: regularization losses collected during graph
+    construction.
   """

   # Key to collect Variable objects that must be saved and restored
@@ -3178,6 +3180,8 @@ class GraphKeys(object):
   ASSET_FILEPATHS = "asset_filepaths"
   # Key to collect Variable objects that keep moving averages.
   MOVING_AVERAGE_VARIABLES = "moving_average_variables"
+  # Key to collected regularization losses at graph construction.
+  REGULARIZATION_LOSSES = "regularization_losses"


 def add_to_collection(name, value):
tensorflow/python/kernel_tests/learn_test.py (new file, 225 lines)
@@ -0,0 +1,225 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for tf.learn."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

import tensorflow.python.platform  # pylint: disable=unused-import,g-bad-import-order

import numpy as np
import six
import tensorflow as tf

from tensorflow.python.framework import tensor_util


class FullyConnectedTest(tf.test.TestCase):

  def setUp(self):
    tf.test.TestCase.setUp(self)
    tf.set_random_seed(1234)
    self.input = tf.constant([[1., 2., 3.], [-4., 5., -6.]])
    assert not tf.get_collection(tf.GraphKeys.SUMMARIES)

  def assert_summary_scope(self, regexp):
    for summary in tf.get_collection(tf.GraphKeys.SUMMARIES):
      tag = tensor_util.ConstantValue(summary.op.inputs[0])
      assert tag is not None, 'All summaries have constant tags'
      tag = str(tag)
      assert isinstance(tag[0], six.string_types), tag[0]
      assert re.match(regexp, tag), "tag doesn't match %s: %s" % (regexp, tag)

  def test_basic_use(self):
    output = tf.learn.fully_connected(self.input, 8, activation_fn=tf.nn.relu)

    with tf.Session() as sess:
      with self.assertRaises(tf.errors.FailedPreconditionError):
        sess.run(output)

      tf.initialize_all_variables().run()
      out_value = sess.run(output)

    self.assertEqual(output.get_shape().as_list(), [2, 8])
    self.assertTrue(np.all(out_value >= 0),
                    'Relu should have capped all values.')

    self.assertGreater(tf.get_collection(tf.GraphKeys.SUMMARIES), 0,
                       'Some summaries should have been added.')
    self.assertEqual(2,
                     len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))
    self.assertEqual(0,
                     len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)))
    self.assert_summary_scope('fully_connected')

  def test_variable_reuse_with_scope(self):
    with tf.variable_scope('test') as vs:
      output1 = tf.learn.fully_connected(self.input,
                                         8,
                                         activation_fn=tf.nn.relu)
      output2 = tf.learn.fully_connected(self.input,
                                         8,
                                         activation_fn=tf.nn.relu)

    with tf.variable_scope(vs, reuse=True):
      output3 = tf.learn.fully_connected(self.input,
                                         8,
                                         activation_fn=tf.nn.relu)

    with tf.Session() as sess:
      tf.initialize_all_variables().run()
      out_value1, out_value2, out_value3 = sess.run([output1, output2, output3])

    self.assertFalse(np.allclose(out_value1, out_value2))
    self.assertAllClose(out_value1, out_value3)

  def test_variable_reuse_with_template(self):
    tmpl1 = tf.make_template('test',
                             tf.learn.fully_connected,
                             num_output_nodes=8)
    output1 = tmpl1(self.input)
    output2 = tmpl1(self.input)

    with tf.Session() as sess:
      tf.initialize_all_variables().run()
      out_value1, out_value2 = sess.run([output1, output2])
    self.assertAllClose(out_value1, out_value2)
    self.assert_summary_scope(r'test(_\d)?/fully_connected')

  def test_custom_initializers(self):
    output = tf.learn.fully_connected(self.input,
                                      2,
                                      activation_fn=tf.nn.relu,
                                      weight_init=tf.constant_initializer(2.0),
                                      bias_init=tf.constant_initializer(1.0))

    with tf.Session() as sess:
      tf.initialize_all_variables().run()
      out_value = sess.run(output)

    self.assertAllClose(np.array([[13.0, 13.0], [0.0, 0.0]]), out_value)

  def test_custom_collections(self):
    tf.learn.fully_connected(self.input,
                             2,
                             activation_fn=tf.nn.relu,
                             weight_collections=['unbiased'],
                             bias_collections=['biased'])

    self.assertEquals(1, len(tf.get_collection('unbiased')))
    self.assertEquals(1, len(tf.get_collection('biased')))

  def test_all_custom_collections(self):
    tf.learn.fully_connected(self.input,
                             2,
                             activation_fn=tf.nn.relu,
                             weight_collections=['unbiased', 'all'],
                             bias_collections=['biased', 'all'])

    self.assertEquals(1, len(tf.get_collection('unbiased')))
    self.assertEquals(1, len(tf.get_collection('biased')))
    self.assertEquals(2, len(tf.get_collection('all')))
    self.assertEquals(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
                      tf.get_collection('all'))

  def test_no_summaries(self):
    tf.learn.fully_connected(self.input,
                             2,
                             activation_fn=tf.nn.relu,
                             create_summaries=False)
    self.assertEquals([], tf.get_collection(tf.GraphKeys.SUMMARIES))

  def test_regularizer(self):
    cnt = [0]
    tensor = tf.constant(5.0)
    def test_fn(_):
      cnt[0] += 1
      return tensor

    tf.learn.fully_connected(self.input, 2, weight_regularizer=test_fn)

    self.assertEqual([tensor],
                     tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    self.assertEqual(1, cnt[0])

  def test_shape_enforcement(self):
    place = tf.placeholder(tf.float32)
    with self.assertRaises(ValueError):
      tf.learn.fully_connected(place, 8)
    tf.learn.fully_connected(place, 8, num_input_nodes=5)  # No error

    place.set_shape([None, None])
    with self.assertRaises(ValueError):
      tf.learn.fully_connected(place, 8)
    tf.learn.fully_connected(place, 8, num_input_nodes=5)  # No error

    place.set_shape([None, 6])
    tf.learn.fully_connected(place, 8)  # No error
    with self.assertRaises(ValueError):
      tf.learn.fully_connected(place, 8, num_input_nodes=5)

    place = tf.placeholder(tf.float32)
    place.set_shape([2, 6, 5])
    with self.assertRaises(ValueError):
      tf.learn.fully_connected(place, 8)

  def test_no_bias(self):
    tf.learn.fully_connected(self.input, 2, bias_init=None)

    self.assertEqual(1,
                     len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))


class RegularizerTest(tf.test.TestCase):

  def test_l1(self):
    with self.assertRaises(ValueError):
      tf.learn.l1_regularizer(2.)
    with self.assertRaises(ValueError):
      tf.learn.l1_regularizer(-1.)
    with self.assertRaises(ValueError):
      tf.learn.l1_regularizer(0)

    self.assertIsNone(tf.learn.l1_regularizer(0.)(None))

    values = np.array([1., -1., 4., 2.])
    weights = tf.constant(values)
    with tf.Session() as sess:
      result = sess.run(tf.learn.l1_regularizer(.5)(weights))

    self.assertAllClose(np.abs(values).sum() * .5, result)

  def test_l2(self):
    with self.assertRaises(ValueError):
      tf.learn.l2_regularizer(2.)
    with self.assertRaises(ValueError):
      tf.learn.l2_regularizer(-1.)
    with self.assertRaises(ValueError):
      tf.learn.l2_regularizer(0)

    self.assertIsNone(tf.learn.l2_regularizer(0.)(None))

    values = np.array([1., -1., 4., 2.])
    weights = tf.constant(values)
    with tf.Session() as sess:
      result = sess.run(tf.learn.l2_regularizer(.42)(weights))

    self.assertAllClose(np.power(values, 2).sum() / 2.0 * .42, result)


if __name__ == '__main__':
  tf.test.main()
tensorflow/python/ops/learn.py (new file, 359 lines)
@@ -0,0 +1,359 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=g-short-docstring-punctuation
"""## Higher level ops related to regularization and building layers.

This package provides several ops that take care of creating variables that are
used internally in a consistent way and provide the building blocks for many
common machine learning algorithms.

@@fully_connected

## Regularizers

Regularization can help prevent overfitting.
These have the signature `fn(weights)`. The loss is typically added to
`tf.GraphKeys.REGULARIZATION_LOSS`

@@l1_regularizer
@@l2_regularizer

## Initializations

This also includes a common initialization for connecting multiple layers.

@@xavier_initializer
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numbers

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import nn
from tensorflow.python.ops import standard_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import logging


__all__ = ['xavier_initializer', 'fully_connected', 'l1_regularizer',
           'l2_regularizer']


def xavier_initializer(n_inputs, n_outputs, uniform=True):
  """Set the parameter initialization using the method described in paper.

  Xavier Glorot and Yoshua Bengio (2010):
    Understanding the difficulty of training deep feedforward neural
    networks. International conference on artificial intelligence and
    statistics.

  This method is designed to keep the scale of the gradients roughly the same
  in all layers. In uniform distribution this ends up being the range:
  `x = sqrt(6. / (in + out)); [-x, x]` and for normal distribution a standard
  deviation of `sqrt(3. / (in + out))` is used.

  Args:
    n_inputs: The number of input nodes into each output.
    n_outputs: The number of output nodes for each input.
    uniform: If true use a uniform distribution, otherwise use a truncated
      normal.

  Returns:
    An initializer.
  """
  if uniform:
    # 6 was used in the paper.
    init_range = math.sqrt(6.0 / (n_inputs + n_outputs))
    return standard_ops.random_uniform_initializer(-init_range, init_range)
  else:
    # 3 gives us approximately the same limits as above since this repicks
    # values greater than 2 standard deviations from the mean.
    stddev = math.sqrt(3.0 / (n_inputs + n_outputs))
    return standard_ops.truncated_normal_initializer(stddev=stddev)


def _assert_summary_tag_unique(tag):
  for summary in ops.get_collection(ops.GraphKeys.SUMMARIES):
    old_tag = tensor_util.ConstantValue(summary.op.inputs[0])
    if tag == str(old_tag):
      raise ValueError('Conflict with summary tag: %s exists on summary %s %s' %
                       (tag, summary, old_tag))


def _add_scalar_summary(tensor, tag=None):
  """Add a summary operation for the tensor.

  Args:
    tensor: The tensor to summarize.
    tag: The tag to use, if None then use tensor's op's name.

  Returns:
    The created histogram summary.

  Raises:
    ValueError: If the tag is already in use or the rank is not 0.
  """
  tensor.get_shape().assert_has_rank(0)
  tag = tag or tensor.op.name
  _assert_summary_tag_unique(tag)
  return standard_ops.scalar_summary(tag, tensor, name='%s_summary' % tag)


def _add_histogram_summary(tensor, tag=None):
  """Add a summary operation for the histogram of a tensor.

  Args:
    tensor: The tensor to summarize.
    tag: The tag to use, if None then use tensor's op's name.

  Returns:
    The created histogram summary.

  Raises:
    ValueError: If the tag is already in use.
  """
  # TODO(opensource): A global or scoped mechanism to disable summaries.
  tag = tag or tensor.op.name
  _assert_summary_tag_unique(tag)
  return standard_ops.histogram_summary(tag, tensor, name='%s_summary' % tag)


def _apply_activation_with_summaries(x, activation_fn):
  """Returns activation_fn(x).

  This applies the given activation and adds useful summaries specific to the
  activation.

  Args:
    x: The tensor to apply activation to.
    activation_fn: An activation function.
  Returns:
    A tensor with activation applied to x.
  """
  if activation_fn is None:
    return x
  y = activation_fn(x)
  if activation_fn in (nn.relu, nn.softplus, nn.relu6):
    # Using x for comparison to avoid floating point equality and/or epsilons.
    _add_scalar_summary(
        standard_ops.reduce_mean(standard_ops.to_float(standard_ops.less(
            x, 0.0))), '%s/zeros' % y.op.name)
  if activation_fn is nn.relu6:
    _add_scalar_summary(
        standard_ops.reduce_mean(standard_ops.to_float(standard_ops.greater(
            x, 6.0))), '%s/sixes' % y.op.name)
  if activation_fn is nn.l2_normalize:
    _add_scalar_summary(
        standard_ops.reduce_mean(standard_ops.sqrt(standard_ops.sum(
            standard_ops.square(x), 1))), '%s/length' % y.op.name)
  _add_histogram_summary(y, '%s/activations' % y.op.name)
  return y


def _apply_regularization(w, regularizer):
  loss = regularizer(w)
  if loss:
    ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss)


def l1_regularizer(scale):
  """Returns a function that can be used to apply L1 regularization to weights.

  L1 regularization encourages sparsity.

  Args:
    scale: A scalar multiplier `Tensor`. 0.0 disables the regularizer.

  Returns:
    A function with signature `l1(weights, name=None)` that apply L1
    regularization.

  Raises:
    ValueError: If scale is outside of the range [0.0, 1.0] or if scale is not a
      float.
  """
  if isinstance(scale, numbers.Integral):
    raise ValueError('scale cannot be an integer: %s' % scale)
  if isinstance(scale, numbers.Real):
    if scale < 0.:
      raise ValueError('Setting a scale less than 0 on a regularizer: %g' %
                       scale)
    if scale >= 1.:
      raise ValueError('Setting a scale greater than 1 on a regularizer: %g' %
                       scale)
    if scale == 0.:
      logging.info('Scale of 0 disables regularizer.')
      return lambda _, name=None: None
  def l1(weights, name=None):
    """Applies L1 regularization to weights."""
    with ops.op_scope([weights], name, 'l1_regularizer') as scope:
      my_scale = ops.convert_to_tensor(scale,
                                       dtype=weights.dtype.base_dtype,
                                       name='scale')
      return standard_ops.mul(
          my_scale,
          standard_ops.reduce_sum(standard_ops.abs(weights)),
          name=scope)
  return l1


def l2_regularizer(scale):
  """Returns a function that can be used to apply L2 regularization to weights.

  Small values of L2 can help prevent overfitting the training data.

  Args:
    scale: A scalar multiplier `Tensor`. 0.0 disables the regularizer.

  Returns:
    A function with signature `l2(weights, name=None)` that applies L2
    regularization.

  Raises:
    ValueError: If scale is outside of the range [0.0, 1.0] or if scale is not a
      float.
  """
  if isinstance(scale, numbers.Integral):
    raise ValueError('scale cannot be an integer: %s' % (scale,))
  if isinstance(scale, numbers.Real):
    if scale < 0.:
      raise ValueError('Setting a scale less than 0 on a regularizer: %g.' %
                       scale)
    if scale >= 1.:
      raise ValueError('Setting a scale greater than 1 on a regularizer: %g.' %
                       scale)
    if scale == 0.:
      logging.info('Scale of 0 disables regularizer.')
      return lambda _, name=None: None
  def l2(weights, name=None):
    """Applies l2 regularization to weights."""
    with ops.op_scope([weights], name, 'l2_regularizer') as scope:
      my_scale = ops.convert_to_tensor(scale,
                                       dtype=weights.dtype.base_dtype,
                                       name='scale')
      return standard_ops.mul(my_scale, nn.l2_loss(weights), name=scope)
  return l2


def fully_connected(x,
                    num_output_nodes,
                    activation_fn=None,
                    weight_init=None,
                    bias_init=standard_ops.constant_initializer(0.),
                    num_input_nodes=None,
                    name=None,
                    weight_collections=None,
                    bias_collections=None,
                    weight_regularizer=None,
                    create_summaries=True):
  """Adds the parameters for a fully connected layer and returns the output.

  A fully connected layer is generally defined as a matrix multiply:
  \\\\(y = f(w * x + b)\\\\) where **f** is given by `activation_fn`

  This op creates `w` and optionally `b` (disable with `bias_init=None`) and
  adds various summaries that can be useful for visualizing learning or
  diagnosing training problems. The variable creation is compatible with
  `tf.variable_scope` and so can be reused with `tf.variable_scope` or
  `tf.make_template`.

  In almost all cases, the number of input nodes can be inferred from the shape
  of `x`, but if it is unspecified or additional size checks are desired, then
  `num_input_nodes` can be specified.

  Most of the details of variable creation can be controlled by specifying the
  initializers (`weight_init` and `bias_init`) and which collections to place
  the created variables in (`weight_collections` and `bias_collections`).

  A per layer regularization can be specified by setting `weight_regularizer`.
  This is only applied to weights and not the bias.

  Args:
    x: The input tensor.
    num_output_nodes: The size of the output.
    activation_fn: A function that requires a single Tensor that is applied as a
      non-linearity. If None is used, then this is a linear layer.
    weight_init: An optional initialization. If not specified, uses Xavier
      initialization (see `tf.learn.xavier_initializer`).
    bias_init: An initializer for the bias, defaults to 0.
    num_input_nodes: The number of input nodes.
    name: The name for this operation is used to name operations and to find
      variables. If specified it must be unique for this scope, otherwise a
      unique name starting with "fully_connected" will be created. See
      `tf.variable_op_scope` for details.
    weight_collections: List of graph collections for just weights.
    bias_collections: List of graph collections for just bias.
    weight_regularizer: A regularizer like the result of
      `tf.learn.l1_regularizer` or `tf.learn.l2_regularizer`.
    create_summaries: Set to false to disable summaries.

  Returns:
    The result of applying a fully connected layer.

  Raises:
    ValueError: if `x` is not rank 2; or `x`'s second dimension is not known
      and `num_input_nodes` is not specified.
  """
  with variable_scope.variable_op_scope([x], name, 'fully_connected') as vs:
    # Check rank and if num_input_nodes is specified, make sure it matches.
    x.get_shape().assert_is_compatible_with([None, num_input_nodes])

    if not num_input_nodes:
      if x.get_shape().dims is None or x.get_shape().dims[1].value is None:
        raise ValueError(
            'If x has an unknown first dimension then num_input_nodes '
            'must be specified; shape: %s num_input_nodes: %s'
            % (x.get_shape(), num_input_nodes))
      else:
        num_input_nodes = x.get_shape().dims[1].value

    weight_init = weight_init or xavier_initializer(
        num_input_nodes, num_output_nodes)

    dtype = x.dtype
    w = variable_scope.get_variable('weights',
                                    shape=[num_input_nodes, num_output_nodes],
                                    dtype=dtype,
                                    initializer=weight_init,
                                    collections=weight_collections)

    if not vs.reuse and create_summaries:
      _add_histogram_summary(w)

    y = standard_ops.matmul(x, w)
    # Regularization is only applied to the weights and not bias.
    if weight_regularizer:
      _apply_regularization(w, weight_regularizer)
    if bias_init:
      b = variable_scope.get_variable('bias',
                                      shape=[num_output_nodes],
                                      dtype=dtype,
                                      initializer=bias_init,
                                      collections=bias_collections)
      if not vs.reuse and create_summaries:
        _add_histogram_summary(b)

      y = nn.bias_add(y, b)

    if create_summaries:
      return _apply_activation_with_summaries(y, activation_fn)
    else:
      return activation_fn(y)
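For orientation, a minimal usage sketch of the `tf.learn.fully_connected` layer added above, distilled from the tests earlier in this commit (0.x-era API; this snippet is illustrative and not part of the commit itself):

```python
import tensorflow as tf

# A small fully connected layer; weight_init defaults to the Xavier initializer
# defined in learn.py, and weights/summaries land in the standard collections.
inputs = tf.constant([[1., 2., 3.], [-4., 5., -6.]])
output = tf.learn.fully_connected(inputs, 8,
                                  activation_fn=tf.nn.relu,
                                  weight_regularizer=tf.learn.l2_regularizer(0.1))

with tf.Session() as sess:
  tf.initialize_all_variables().run()
  print(sess.run(output).shape)  # (2, 8)
  # The regularizer's loss was added to the REGULARIZATION_LOSSES collection.
  print(len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)))  # 1
```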
@@ -559,6 +559,7 @@ class OpDefLibrary(object):
                              "less than minimum %d." %
                              (key, op_type_name, len(value),
                               attr_def.minimum))
+        attr_value.list.SetInParent()
       if attr_def.type == "string":
         attr_value.s = _MakeStr(value, key)
         if attr_def.HasField("allowed_values"):
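The one-line `attr_value.list.SetInParent()` change above is the core of the fix described in the release note. A short sketch of why it matters at the protobuf level (the generated proto module path is the standard one and is assumed here; this is not code from the commit):

```python
from tensorflow.core.framework import attr_value_pb2

attr = attr_value_pb2.AttrValue()
print(attr.HasField("list"))     # False: an untouched `list` field is simply absent
print(attr.SerializeToString())  # b'': the empty-list attr serializes to nothing

# SetInParent() marks the empty ListValue as present, so an empty list attr is
# distinguishable from a missing one and round-trips through a GraphDef.
attr.list.SetInParent()
print(attr.HasField("list"))     # True
print(attr.SerializeToString())  # non-empty: the `list` field tag plus a zero length
```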