Fix bug where attrs whose value is the empty list were not being properly set via the Python API.
Change: 111635679
Josh Levenberg 2016-01-07 18:37:54 -08:00 committed by Vijay Vasudevan
parent d38fecedf5
commit 02dff6d0d8
23 changed files with 670 additions and 42 deletions

View File

@ -11,6 +11,14 @@
safety is handled by `saturate_cast`, which makes sure over- and underflows safety is handled by `saturate_cast`, which makes sure over- and underflows
are handled before casting to data types with smaller ranges. are handled before casting to data types with smaller ranges.
## Bug fixes
* The Python API will now properly set the `list` member of `AttrValue` in
constructed `GraphDef` messages for empty lists. The serialization of some
graphs will change, but the change is both forwards and backwards compatible.
It will break tests that compare a generated `GraphDef` to a golden serialized
`GraphDef`.
# Release 0.6.0 # Release 0.6.0
## Major Features and Improvements ## Major Features and Improvements

View File

@ -121,6 +121,19 @@ class OpKernel {
Status InputRange(const string& input_name, int* start, int* stop) const; Status InputRange(const string& input_name, int* start, int* stop) const;
Status OutputRange(const string& output_name, int* start, int* stop) const; Status OutputRange(const string& output_name, int* start, int* stop) const;
// TODO(irving): At the moment, the following three functions forward to
// TensorShapeUtils, but they are about to become the only versions once we
// become scalar strict.
bool allow_legacy_scalars() const { return kAllowLegacyScalars; }
bool IsLegacyScalar(const TensorShape& shape) const {
return TensorShapeUtils::IsLegacyScalar(shape);
}
bool IsLegacyVector(const TensorShape& shape) const {
return TensorShapeUtils::IsLegacyVector(shape);
}
private: private:
const NodeDef def_; const NodeDef def_;
const DataTypeVector input_types_; const DataTypeVector input_types_;
@ -455,6 +468,8 @@ class OpKernelContext {
Env* env() const { return params_.device->env(); } Env* env() const { return params_.device->env(); }
const OpKernel& op_kernel() const { return *params_.op_kernel; }
// Input/output signature. // Input/output signature.
int num_inputs() const { return params_.inputs->size(); } int num_inputs() const { return params_.inputs->size(); }

View File

@ -45,7 +45,7 @@ class ConcatOp : public OpKernel {
const Tensor* concat_dim_tensor; const Tensor* concat_dim_tensor;
OP_REQUIRES_OK(c, c->input("concat_dim", &concat_dim_tensor)); OP_REQUIRES_OK(c, c->input("concat_dim", &concat_dim_tensor));
OP_REQUIRES( OP_REQUIRES(
c, TensorShapeUtils::IsLegacyScalar(concat_dim_tensor->shape()), c, IsLegacyScalar(concat_dim_tensor->shape()),
errors::InvalidArgument( errors::InvalidArgument(
"Concat dim tensor should be a scalar integer, but got shape ", "Concat dim tensor should be a scalar integer, but got shape ",
concat_dim_tensor->shape().DebugString())); concat_dim_tensor->shape().DebugString()));
@ -57,7 +57,7 @@ class ConcatOp : public OpKernel {
const TensorShape& input_shape = values[0].shape(); const TensorShape& input_shape = values[0].shape();
OP_REQUIRES( OP_REQUIRES(
c, (0 <= concat_dim && concat_dim < input_dims) || c, (0 <= concat_dim && concat_dim < input_dims) ||
(kAllowLegacyScalars && concat_dim == 0), (allow_legacy_scalars() && concat_dim == 0),
errors::InvalidArgument( errors::InvalidArgument(
"ConcatOp : Expected concatenating dimensions in the range [", 0, "ConcatOp : Expected concatenating dimensions in the range [", 0,
", ", input_dims, "), but got ", concat_dim)); ", ", input_dims, "), but got ", concat_dim));
@ -74,10 +74,10 @@ class ConcatOp : public OpKernel {
inputs_flat_dim0 *= input_shape.dim_size(d); inputs_flat_dim0 *= input_shape.dim_size(d);
} }
int output_concat_dim = 0; int output_concat_dim = 0;
const bool input_is_scalar = TensorShapeUtils::IsLegacyScalar(input_shape); const bool input_is_scalar = IsLegacyScalar(input_shape);
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
const auto in = values[i]; const auto in = values[i];
const bool in_is_scalar = TensorShapeUtils::IsLegacyScalar(in.shape()); const bool in_is_scalar = IsLegacyScalar(in.shape());
OP_REQUIRES( OP_REQUIRES(
c, in.dims() == input_dims || (input_is_scalar && in_is_scalar), c, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
errors::InvalidArgument( errors::InvalidArgument(
@ -100,12 +100,12 @@ class ConcatOp : public OpKernel {
inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix( inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1}))); in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
} }
// TODO(irving): Remove check once !kAllowLegacyScalars // TODO(irving): Remove check once !allow_legacy_scalars().
output_concat_dim += in.dims() > 0 ? in.dim_size(concat_dim) : 1; output_concat_dim += in.dims() > 0 ? in.dim_size(concat_dim) : 1;
} }
TensorShape output_shape(input_shape); TensorShape output_shape(input_shape);
// TODO(irving): Remove rank 0 case once !kAllowLegacyScalars // TODO(irving): Remove rank 0 case once !allow_legacy_scalars().
if (output_shape.dims() == 0) { if (output_shape.dims() == 0) {
output_shape.AddDim(output_concat_dim); output_shape.AddDim(output_concat_dim);
} else { } else {

View File

@ -143,11 +143,14 @@ class FillOp : public OpKernel {
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
const Tensor& Tdims = context->input(0); const Tensor& Tdims = context->input(0);
OP_REQUIRES(context, TensorShapeUtils::IsLegacyVector(Tdims.shape()), OP_REQUIRES(
errors::InvalidArgument("dims must be a vector of int32.")); context, IsLegacyVector(Tdims.shape()),
errors::InvalidArgument("dims must be a vector of int32, got shape ",
Tdims.shape().ShortDebugString()));
const Tensor& Tvalue = context->input(1); const Tensor& Tvalue = context->input(1);
OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(Tvalue.shape()), OP_REQUIRES(context, IsLegacyScalar(Tvalue.shape()),
errors::InvalidArgument("value must be a scalar.")); errors::InvalidArgument("value must be a scalar, got shape ",
Tvalue.shape().ShortDebugString()));
auto dims = Tdims.flat<int32>(); auto dims = Tdims.flat<int32>();
for (int i = 0; i < dims.size(); i++) { for (int i = 0; i < dims.size(); i++) {
OP_REQUIRES(context, dims(i) >= 0, OP_REQUIRES(context, dims(i) >= 0,

View File

@ -28,7 +28,7 @@ class AssertOp : public OpKernel {
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
const Tensor& cond = ctx->input(0); const Tensor& cond = ctx->input(0);
OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(cond.shape()), OP_REQUIRES(ctx, IsLegacyScalar(cond.shape()),
errors::InvalidArgument("In[0] should be a scalar: ", errors::InvalidArgument("In[0] should be a scalar: ",
cond.shape().ShortDebugString())); cond.shape().ShortDebugString()));

View File

@ -59,7 +59,8 @@ class PadOp : public OpKernel {
errors::InvalidArgument("paddings must be a matrix with 2 columns: ", errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
in1.shape().DebugString())); in1.shape().DebugString()));
const int fixed_dims = const int fixed_dims =
(kAllowLegacyScalars && dims == 0 && in1.dim_size(0) == 1) ? 1 : dims; (allow_legacy_scalars() && dims == 0 && in1.dim_size(0) == 1) ? 1
: dims;
OP_REQUIRES( OP_REQUIRES(
context, fixed_dims == in1.dim_size(0), context, fixed_dims == in1.dim_size(0),
errors::InvalidArgument( errors::InvalidArgument(
@ -76,7 +77,7 @@ class PadOp : public OpKernel {
errors::InvalidArgument("Paddings must be non-negative: ", errors::InvalidArgument("Paddings must be non-negative: ",
before_d, " ", after_d)); before_d, " ", after_d));
const int size_d = const int size_d =
(kAllowLegacyScalars && d == in0.dims()) ? 1 : in0.dim_size(d); (allow_legacy_scalars() && d == in0.dims()) ? 1 : in0.dim_size(d);
output_shape.AddDim(before_d + size_d + after_d); output_shape.AddDim(before_d + size_d + after_d);
} }
Tensor* output = nullptr; Tensor* output = nullptr;
@ -89,7 +90,7 @@ class PadOp : public OpKernel {
break; break;
case 1: case 1:
// TODO(irving): Once Pad doesn't need a scalar special case, // TODO(irving): Once Pad doesn't need a scalar special case,
// change flat to tensor. That is, once !kAllowLegacyScalars. // change flat to tensor. That is, once !allow_legacy_scalars().
Operate<1>(context, in0.flat<T>(), paddings, output); Operate<1>(context, in0.flat<T>(), paddings, output);
break; break;
case 2: case 2:

View File

@ -180,7 +180,7 @@ namespace {
static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape, static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape,
int index, Tensor** output) { int index, Tensor** output) {
if (!TensorShapeUtils::IsLegacyVector(shape.shape())) { if (!ctx->op_kernel().IsLegacyVector(shape.shape())) {
return errors::InvalidArgument( return errors::InvalidArgument(
"shape must be a vector of {int32,int64}, got shape ", "shape must be a vector of {int32,int64}, got shape ",
shape.shape().ShortDebugString()); shape.shape().ShortDebugString());

View File

@ -35,7 +35,7 @@ class ReshapeOp : public OpKernel {
const Tensor& input = context->input(0); const Tensor& input = context->input(0);
const Tensor& sizes = context->input(1); const Tensor& sizes = context->input(1);
// Preliminary validation of sizes. // Preliminary validation of sizes.
OP_REQUIRES(context, TensorShapeUtils::IsLegacyVector(sizes.shape()), OP_REQUIRES(context, IsLegacyVector(sizes.shape()),
errors::InvalidArgument("sizes input must be 1-D, not shape ", errors::InvalidArgument("sizes input must be 1-D, not shape ",
sizes.shape().ShortDebugString())); sizes.shape().ShortDebugString()));
const int64 num_dims = sizes.NumElements(); const int64 num_dims = sizes.NumElements();

View File

@ -55,7 +55,7 @@ class ShardedFilenameOp : public OpKernel {
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
static const char* input_names[3] = {"basename", "shard", "num_shards"}; static const char* input_names[3] = {"basename", "shard", "num_shards"};
for (int i = 0; i < ctx->num_inputs(); ++i) { for (int i = 0; i < ctx->num_inputs(); ++i) {
OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(ctx->input(i).shape()), OP_REQUIRES(ctx, IsLegacyScalar(ctx->input(i).shape()),
errors::InvalidArgument( errors::InvalidArgument(
input_names[i], " must be a scalar, got shape ", input_names[i], " must be a scalar, got shape ",
ctx->input(i).shape().ShortDebugString())); ctx->input(i).shape().ShortDebugString()));
@ -78,7 +78,7 @@ class ShardedFilespecOp : public OpKernel {
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
static const char* input_names[2] = {"basename", "num_shards"}; static const char* input_names[2] = {"basename", "num_shards"};
for (int i = 0; i < ctx->num_inputs(); ++i) { for (int i = 0; i < ctx->num_inputs(); ++i) {
OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(ctx->input(i).shape()), OP_REQUIRES(ctx, IsLegacyScalar(ctx->input(i).shape()),
errors::InvalidArgument( errors::InvalidArgument(
input_names[i], " must be a scalar, got shape ", input_names[i], " must be a scalar, got shape ",
ctx->input(i).shape().ShortDebugString())); ctx->input(i).shape().ShortDebugString()));

View File

@ -184,7 +184,7 @@ class UnsortedSegmentSumOp : public OpKernel {
const Tensor& num_segments = context->input(2); const Tensor& num_segments = context->input(2);
OP_REQUIRES( OP_REQUIRES(
context, TensorShapeUtils::IsLegacyScalar(num_segments.shape()), context, IsLegacyScalar(num_segments.shape()),
errors::InvalidArgument("num_segments should be a scalar, not shape ", errors::InvalidArgument("num_segments should be a scalar, not shape ",
num_segments.shape().ShortDebugString())); num_segments.shape().ShortDebugString()));
@ -406,7 +406,7 @@ class SparseSegmentMeanGradOp : public OpKernel {
errors::InvalidArgument("indices should be a vector.")); errors::InvalidArgument("indices should be a vector."));
OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()), OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
errors::InvalidArgument("segment_ids should be a vector.")); errors::InvalidArgument("segment_ids should be a vector."));
OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(output_dim0.shape()), OP_REQUIRES(context, IsLegacyScalar(output_dim0.shape()),
errors::InvalidArgument("output_dim0 should be a scalar.")); errors::InvalidArgument("output_dim0 should be a scalar."));
const int64 N = indices.NumElements(); const int64 N = indices.NumElements();

View File

@ -34,13 +34,13 @@ class RangeOp : public OpKernel {
const Tensor& start_in = context->input(0); const Tensor& start_in = context->input(0);
const Tensor& limit_in = context->input(1); const Tensor& limit_in = context->input(1);
const Tensor& delta_in = context->input(2); const Tensor& delta_in = context->input(2);
OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(start_in.shape()), OP_REQUIRES(context, IsLegacyScalar(start_in.shape()),
errors::InvalidArgument("start must be a scalar, not shape ", errors::InvalidArgument("start must be a scalar, not shape ",
start_in.shape().ShortDebugString())); start_in.shape().ShortDebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(limit_in.shape()), OP_REQUIRES(context, IsLegacyScalar(limit_in.shape()),
errors::InvalidArgument("limit must be a scalar, not shape ", errors::InvalidArgument("limit must be a scalar, not shape ",
limit_in.shape().ShortDebugString())); limit_in.shape().ShortDebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(delta_in.shape()), OP_REQUIRES(context, IsLegacyScalar(delta_in.shape()),
errors::InvalidArgument("delta must be a scalar, not shape ", errors::InvalidArgument("delta must be a scalar, not shape ",
delta_in.shape().ShortDebugString())); delta_in.shape().ShortDebugString()));
const int32 start = GetValue(start_in.scalar<T>()()); const int32 start = GetValue(start_in.scalar<T>()());

View File

@ -69,14 +69,15 @@ static void SharedValidation(OpKernelContext* context,
const Tensor& size_tensor = context->input(2); const Tensor& size_tensor = context->input(2);
OP_REQUIRES( OP_REQUIRES(
context, TensorShapeUtils::IsLegacyVector(begin_tensor.shape()) && context, context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
TensorShapeUtils::IsLegacyVector(size_tensor.shape()) && context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
begin_tensor.NumElements() == input.dims() && begin_tensor.NumElements() == input.dims() &&
size_tensor.NumElements() == input.dims(), size_tensor.NumElements() == input.dims(),
errors::InvalidArgument( errors::InvalidArgument(
"Expected begin and size arguments to be 1-D tensors of size ", "Expected begin and size arguments to be 1-D tensors of size ",
input.dims(), ", but got ", begin_tensor.NumElements(), " and ", input.dims(), ", but got shapes ",
size_tensor.NumElements(), " instead.")); begin_tensor.shape().ShortDebugString(), " and ",
size_tensor.shape().ShortDebugString(), " instead."));
const int input_dims = input.dims(); const int input_dims = input.dims();
*begin = IntTensorToInt64Vec(begin_tensor); *begin = IntTensorToInt64Vec(begin_tensor);

View File

@ -60,7 +60,7 @@ class SparseToDense : public OpKernel {
// output_shape // output_shape
const Tensor& output_shape = c->input(1); const Tensor& output_shape = c->input(1);
OP_REQUIRES( OP_REQUIRES(
c, TensorShapeUtils::IsLegacyVector(output_shape.shape()), c, IsLegacyVector(output_shape.shape()),
errors::InvalidArgument("output_shape should be a vector, got shape ", errors::InvalidArgument("output_shape should be a vector, got shape ",
output_shape.shape().ShortDebugString())); output_shape.shape().ShortDebugString()));
OP_REQUIRES(c, output_shape.NumElements() == num_dims, OP_REQUIRES(c, output_shape.NumElements() == num_dims,

View File

@ -48,8 +48,8 @@ class SummaryImageOp : public OpKernel {
void Compute(OpKernelContext* c) override { void Compute(OpKernelContext* c) override {
const Tensor& tags = c->input(0); const Tensor& tags = c->input(0);
const Tensor& tensor = c->input(1); const Tensor& tensor = c->input(1);
OP_REQUIRES(c, TensorShapeUtils::IsLegacyScalar(tags.shape()), OP_REQUIRES(c, IsLegacyScalar(tags.shape()),
errors::InvalidArgument("Tags must have be a scalar")); errors::InvalidArgument("Tags must be a scalar"));
OP_REQUIRES(c, tensor.dims() == 4 && OP_REQUIRES(c, tensor.dims() == 4 &&
(tensor.dim_size(3) == 1 || tensor.dim_size(3) == 3 || (tensor.dim_size(3) == 1 || tensor.dim_size(3) == 3 ||
tensor.dim_size(3) == 4), tensor.dim_size(3) == 4),

View File

@ -40,12 +40,12 @@ class SummaryScalarOp : public OpKernel {
const Tensor& tags = c->input(0); const Tensor& tags = c->input(0);
const Tensor& values = c->input(1); const Tensor& values = c->input(1);
OP_REQUIRES(c, tags.IsSameSize(values) || OP_REQUIRES(c, tags.IsSameSize(values) || (IsLegacyScalar(tags.shape()) &&
(TensorShapeUtils::IsLegacyScalar(tags.shape()) && IsLegacyScalar(values.shape())),
TensorShapeUtils::IsLegacyScalar(values.shape())),
errors::InvalidArgument("tags and values not the same shape: ", errors::InvalidArgument("tags and values not the same shape: ",
tags.shape().ShortDebugString(), " != ", tags.shape().ShortDebugString(), " != ",
values.shape().ShortDebugString())); values.shape().ShortDebugString(),
SingleTag(tags)));
auto Ttags = tags.flat<string>(); auto Ttags = tags.flat<string>();
auto Tvalues = values.flat<T>(); auto Tvalues = values.flat<T>();
Summary s; Summary s;
@ -59,6 +59,15 @@ class SummaryScalarOp : public OpKernel {
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor)); OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
CHECK(s.SerializeToString(&summary_tensor->scalar<string>()())); CHECK(s.SerializeToString(&summary_tensor->scalar<string>()()));
} }
// If there's only one tag, include it in the error message
static string SingleTag(const Tensor& tags) {
if (tags.NumElements() == 1) {
return strings::StrCat(" (tag '", tags.flat<string>()(0), "')");
} else {
return "";
}
}
}; };
template <typename T> template <typename T>
@ -72,7 +81,7 @@ class SummaryHistoOp : public OpKernel {
const Tensor& tags = c->input(0); const Tensor& tags = c->input(0);
const Tensor& values = c->input(1); const Tensor& values = c->input(1);
const auto flat = values.flat<T>(); const auto flat = values.flat<T>();
OP_REQUIRES(c, TensorShapeUtils::IsLegacyScalar(tags.shape()), OP_REQUIRES(c, IsLegacyScalar(tags.shape()),
errors::InvalidArgument("tags must be scalar")); errors::InvalidArgument("tags must be scalar"));
// Build histogram of values in "values" tensor // Build histogram of values in "values" tensor
histogram::Histogram histo; histogram::Histogram histo;

View File

@ -46,7 +46,7 @@ class TileOp : public OpKernel {
const Tensor& multiples = context->input(1); const Tensor& multiples = context->input(1);
OP_REQUIRES( OP_REQUIRES(
context, TensorShapeUtils::IsLegacyVector(multiples.shape()), context, IsLegacyVector(multiples.shape()),
errors::InvalidArgument("Expected multiples to be 1-D, but got shape ", errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
multiples.shape().ShortDebugString())); multiples.shape().ShortDebugString()));
OP_REQUIRES(context, input.dims() == multiples.NumElements(), OP_REQUIRES(context, input.dims() == multiples.NumElements(),
@ -192,7 +192,7 @@ class TileGradientOp : public OpKernel {
const Tensor& input = context->input(0); const Tensor& input = context->input(0);
const Tensor& multiples = context->input(1); const Tensor& multiples = context->input(1);
OP_REQUIRES( OP_REQUIRES(
context, TensorShapeUtils::IsLegacyVector(multiples.shape()), context, IsLegacyVector(multiples.shape()),
errors::InvalidArgument("Expected multiples to be 1-D, but got shape ", errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
multiples.shape().ShortDebugString())); multiples.shape().ShortDebugString()));
OP_REQUIRES(context, input.dims() == multiples.NumElements(), OP_REQUIRES(context, input.dims() == multiples.NumElements(),

View File

@ -153,7 +153,7 @@ class ApplyGradientDescentOp : public OpKernel {
errors::FailedPrecondition( errors::FailedPrecondition(
"Attempting to use uninitialized variables: ", def().input(0))); "Attempting to use uninitialized variables: ", def().input(0)));
const Tensor& alpha = ctx->input(1); const Tensor& alpha = ctx->input(1);
OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(alpha.shape()), OP_REQUIRES(ctx, IsLegacyScalar(alpha.shape()),
errors::InvalidArgument("alpha is not a scalar: ", errors::InvalidArgument("alpha is not a scalar: ",
alpha.shape().DebugString())); alpha.shape().DebugString()));
const Tensor& delta = ctx->input(2); const Tensor& delta = ctx->input(2);
@ -242,7 +242,7 @@ class ApplyAdagradOp : public OpKernel {
errors::FailedPrecondition( errors::FailedPrecondition(
"Attempting to use uninitialized variables: ", def().input(1))); "Attempting to use uninitialized variables: ", def().input(1)));
const Tensor& lr = ctx->input(2); const Tensor& lr = ctx->input(2);
OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(lr.shape()), OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
errors::InvalidArgument("lr is not a scalar: ", errors::InvalidArgument("lr is not a scalar: ",
lr.shape().DebugString())); lr.shape().DebugString()));
const Tensor& grad = ctx->input(3); const Tensor& grad = ctx->input(3);
@ -336,7 +336,7 @@ class SparseApplyAdagradOp : public OpKernel {
errors::InvalidArgument("var must be at least 1 dimensional")); errors::InvalidArgument("var must be at least 1 dimensional"));
const Tensor& lr = ctx->input(2); const Tensor& lr = ctx->input(2);
OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(lr.shape()), OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
errors::InvalidArgument("lr is not a scalar: ", errors::InvalidArgument("lr is not a scalar: ",
lr.shape().DebugString())); lr.shape().DebugString()));
const Tensor& grad = ctx->input(3); const Tensor& grad = ctx->input(3);

View File

@ -683,6 +683,7 @@ py_library(
"ops/image_ops.py", "ops/image_ops.py",
"ops/init_ops.py", "ops/init_ops.py",
"ops/io_ops.py", "ops/io_ops.py",
"ops/learn.py",
"ops/linalg_grad.py", "ops/linalg_grad.py",
"ops/linalg_ops.py", "ops/linalg_ops.py",
"ops/logging_ops.py", "ops/logging_ops.py",

View File

@ -57,7 +57,8 @@ from tensorflow.python.client.client_lib import *
# Ops # Ops
from tensorflow.python.ops.standard_ops import * from tensorflow.python.ops.standard_ops import *
# Bring nn, image_ops, user_ops, compat as a subpackages # Bring learn, nn, image_ops, user_ops, compat as a subpackages
from tensorflow.python.ops import learn
from tensorflow.python.ops import nn from tensorflow.python.ops import nn
from tensorflow.python.ops import image_ops as image from tensorflow.python.ops import image_ops as image
from tensorflow.python.user_ops import user_ops from tensorflow.python.user_ops import user_ops
@ -77,7 +78,7 @@ from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test from tensorflow.python.platform import test
# Don't export modules except for the few we really want # Don't export modules except for the few we really want
_whitelist = set([app, compat, errors, flags, image, logging, nn, _whitelist = set([app, compat, errors, flags, image, learn, logging, nn,
python_io, resource_loader, test, train, user_ops]) python_io, resource_loader, test, train, user_ops])
# TODO(b/25561952): tf.tensor_util is DEPRECATED. Please avoid. # TODO(b/25561952): tf.tensor_util is DEPRECATED. Please avoid.
_whitelist.update([tensor_util]) # pylint: disable=undefined-variable _whitelist.update([tensor_util]) # pylint: disable=undefined-variable

View File

@ -3159,6 +3159,8 @@ class GraphKeys(object):
keep moving averages. See keep moving averages. See
[`tf.moving_average_variables()`](../../api_docs/python/state_ops.md#moving_average_variables) [`tf.moving_average_variables()`](../../api_docs/python/state_ops.md#moving_average_variables)
for more details. for more details.
* `REGULARIZATION_LOSSES`: regularization losses collected during graph
construction.
""" """
# Key to collect Variable objects that must be saved and restored # Key to collect Variable objects that must be saved and restored
@ -3178,6 +3180,8 @@ class GraphKeys(object):
ASSET_FILEPATHS = "asset_filepaths" ASSET_FILEPATHS = "asset_filepaths"
# Key to collect Variable objects that keep moving averages. # Key to collect Variable objects that keep moving averages.
MOVING_AVERAGE_VARIABLES = "moving_average_variables" MOVING_AVERAGE_VARIABLES = "moving_average_variables"
# Key to collect regularization losses at graph construction.
REGULARIZATION_LOSSES = "regularization_losses"
def add_to_collection(name, value): def add_to_collection(name, value):
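To make the new collection key concrete, here is a minimal sketch (the tensor name below is hypothetical, not part of this commit) of adding a regularization loss during graph construction and summing the collection later:

```python
import tensorflow as tf

# Hypothetical penalty term collected under the new key.
penalty = tf.constant(0.1, name="l2_penalty")
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, penalty)

# Later, e.g. when assembling the training loss:
total_regularization = tf.add_n(
    tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
```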

View File

@ -0,0 +1,225 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.learn."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import tensorflow.python.platform # pylint: disable=unused-import,g-bad-import-order
import numpy as np
import six
import tensorflow as tf
from tensorflow.python.framework import tensor_util
class FullyConnectedTest(tf.test.TestCase):
def setUp(self):
tf.test.TestCase.setUp(self)
tf.set_random_seed(1234)
self.input = tf.constant([[1., 2., 3.], [-4., 5., -6.]])
assert not tf.get_collection(tf.GraphKeys.SUMMARIES)
def assert_summary_scope(self, regexp):
for summary in tf.get_collection(tf.GraphKeys.SUMMARIES):
tag = tensor_util.ConstantValue(summary.op.inputs[0])
assert tag is not None, 'All summaries have constant tags'
tag = str(tag)
assert isinstance(tag[0], six.string_types), tag[0]
assert re.match(regexp, tag), "tag doesn't match %s: %s" % (regexp, tag)
def test_basic_use(self):
output = tf.learn.fully_connected(self.input, 8, activation_fn=tf.nn.relu)
with tf.Session() as sess:
with self.assertRaises(tf.errors.FailedPreconditionError):
sess.run(output)
tf.initialize_all_variables().run()
out_value = sess.run(output)
self.assertEqual(output.get_shape().as_list(), [2, 8])
self.assertTrue(np.all(out_value >= 0),
'Relu should have capped all values.')
self.assertGreater(len(tf.get_collection(tf.GraphKeys.SUMMARIES)), 0,
'Some summaries should have been added.')
self.assertEqual(2,
len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))
self.assertEqual(0,
len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)))
self.assert_summary_scope('fully_connected')
def test_variable_reuse_with_scope(self):
with tf.variable_scope('test') as vs:
output1 = tf.learn.fully_connected(self.input,
8,
activation_fn=tf.nn.relu)
output2 = tf.learn.fully_connected(self.input,
8,
activation_fn=tf.nn.relu)
with tf.variable_scope(vs, reuse=True):
output3 = tf.learn.fully_connected(self.input,
8,
activation_fn=tf.nn.relu)
with tf.Session() as sess:
tf.initialize_all_variables().run()
out_value1, out_value2, out_value3 = sess.run([output1, output2, output3])
self.assertFalse(np.allclose(out_value1, out_value2))
self.assertAllClose(out_value1, out_value3)
def test_variable_reuse_with_template(self):
tmpl1 = tf.make_template('test',
tf.learn.fully_connected,
num_output_nodes=8)
output1 = tmpl1(self.input)
output2 = tmpl1(self.input)
with tf.Session() as sess:
tf.initialize_all_variables().run()
out_value1, out_value2 = sess.run([output1, output2])
self.assertAllClose(out_value1, out_value2)
self.assert_summary_scope(r'test(_\d)?/fully_connected')
def test_custom_initializers(self):
output = tf.learn.fully_connected(self.input,
2,
activation_fn=tf.nn.relu,
weight_init=tf.constant_initializer(2.0),
bias_init=tf.constant_initializer(1.0))
with tf.Session() as sess:
tf.initialize_all_variables().run()
out_value = sess.run(output)
self.assertAllClose(np.array([[13.0, 13.0], [0.0, 0.0]]), out_value)
def test_custom_collections(self):
tf.learn.fully_connected(self.input,
2,
activation_fn=tf.nn.relu,
weight_collections=['unbiased'],
bias_collections=['biased'])
self.assertEquals(1, len(tf.get_collection('unbiased')))
self.assertEquals(1, len(tf.get_collection('biased')))
def test_all_custom_collections(self):
tf.learn.fully_connected(self.input,
2,
activation_fn=tf.nn.relu,
weight_collections=['unbiased', 'all'],
bias_collections=['biased', 'all'])
self.assertEquals(1, len(tf.get_collection('unbiased')))
self.assertEquals(1, len(tf.get_collection('biased')))
self.assertEquals(2, len(tf.get_collection('all')))
self.assertEquals(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
tf.get_collection('all'))
def test_no_summaries(self):
tf.learn.fully_connected(self.input,
2,
activation_fn=tf.nn.relu,
create_summaries=False)
self.assertEquals([], tf.get_collection(tf.GraphKeys.SUMMARIES))
def test_regularizer(self):
cnt = [0]
tensor = tf.constant(5.0)
def test_fn(_):
cnt[0] += 1
return tensor
tf.learn.fully_connected(self.input, 2, weight_regularizer=test_fn)
self.assertEqual([tensor],
tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
self.assertEqual(1, cnt[0])
def test_shape_enforcement(self):
place = tf.placeholder(tf.float32)
with self.assertRaises(ValueError):
tf.learn.fully_connected(place, 8)
tf.learn.fully_connected(place, 8, num_input_nodes=5) # No error
place.set_shape([None, None])
with self.assertRaises(ValueError):
tf.learn.fully_connected(place, 8)
tf.learn.fully_connected(place, 8, num_input_nodes=5) # No error
place.set_shape([None, 6])
tf.learn.fully_connected(place, 8) # No error
with self.assertRaises(ValueError):
tf.learn.fully_connected(place, 8, num_input_nodes=5)
place = tf.placeholder(tf.float32)
place.set_shape([2, 6, 5])
with self.assertRaises(ValueError):
tf.learn.fully_connected(place, 8)
def test_no_bias(self):
tf.learn.fully_connected(self.input, 2, bias_init=None)
self.assertEqual(1,
len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))
class RegularizerTest(tf.test.TestCase):
def test_l1(self):
with self.assertRaises(ValueError):
tf.learn.l1_regularizer(2.)
with self.assertRaises(ValueError):
tf.learn.l1_regularizer(-1.)
with self.assertRaises(ValueError):
tf.learn.l1_regularizer(0)
self.assertIsNone(tf.learn.l1_regularizer(0.)(None))
values = np.array([1., -1., 4., 2.])
weights = tf.constant(values)
with tf.Session() as sess:
result = sess.run(tf.learn.l1_regularizer(.5)(weights))
self.assertAllClose(np.abs(values).sum() * .5, result)
def test_l2(self):
with self.assertRaises(ValueError):
tf.learn.l2_regularizer(2.)
with self.assertRaises(ValueError):
tf.learn.l2_regularizer(-1.)
with self.assertRaises(ValueError):
tf.learn.l2_regularizer(0)
self.assertIsNone(tf.learn.l2_regularizer(0.)(None))
values = np.array([1., -1., 4., 2.])
weights = tf.constant(values)
with tf.Session() as sess:
result = sess.run(tf.learn.l2_regularizer(.42)(weights))
self.assertAllClose(np.power(values, 2).sum() / 2.0 * .42, result)
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,359 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-short-docstring-punctuation
"""## Higher level ops related to regularization and building layers.
This package provides several ops that take care of creating variables that are
used internally in a consistent way and provide the building blocks for many
common machine learning algorithms.
@@fully_connected
## Regularizers
Regularization can help prevent overfitting.
These have the signature `fn(weights)`. The loss is typically added to
`tf.GraphKeys.REGULARIZATION_LOSSES`.
@@l1_regularizer
@@l2_regularizer
## Initializations
This also includes a common initialization for connecting multiple layers.
@@xavier_initializer
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numbers
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import nn
from tensorflow.python.ops import standard_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import logging
__all__ = ['xavier_initializer', 'fully_connected', 'l1_regularizer',
'l2_regularizer']
def xavier_initializer(n_inputs, n_outputs, uniform=True):
"""Set the parameter initialization using the method described in paper.
Xavier Glorot and Yoshua Bengio (2010):
Understanding the difficulty of training deep feedforward neural
networks. International conference on artificial intelligence and
statistics.
This method is designed to keep the scale of the gradients roughly the same
in all layers. In uniform distribution this ends up being the range:
`x = sqrt(6. / (in + out)); [-x, x]` and for normal distribution a standard
deviation of `sqrt(3. / (in + out))` is used.
Args:
n_inputs: The number of input nodes into each output.
n_outputs: The number of output nodes for each input.
uniform: If true use a uniform distribution, otherwise use a truncated
normal.
Returns:
An initializer.
"""
if uniform:
# 6 was used in the paper.
init_range = math.sqrt(6.0 / (n_inputs + n_outputs))
return standard_ops.random_uniform_initializer(-init_range, init_range)
else:
# 3 gives us approximately the same limits as above since this repicks
# values greater than 2 standard deviations from the mean.
stddev = math.sqrt(3.0 / (n_inputs + n_outputs))
return standard_ops.truncated_normal_initializer(stddev=stddev)
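As a usage sketch (the shapes and variable name are illustrative, not part of this commit), the returned initializer plugs directly into `tf.get_variable`:

```python
import tensorflow as tf

# Illustrative 784 -> 256 weight matrix initialized with the Xavier scheme.
init = tf.learn.xavier_initializer(n_inputs=784, n_outputs=256, uniform=True)
w = tf.get_variable("weights", shape=[784, 256], initializer=init)
```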
def _assert_summary_tag_unique(tag):
for summary in ops.get_collection(ops.GraphKeys.SUMMARIES):
old_tag = tensor_util.ConstantValue(summary.op.inputs[0])
if tag == str(old_tag):
raise ValueError('Conflict with summary tag: %s exists on summary %s %s' %
(tag, summary, old_tag))
def _add_scalar_summary(tensor, tag=None):
"""Add a summary operation for the tensor.
Args:
tensor: The tensor to summarize.
tag: The tag to use, if None then use tensor's op's name.
Returns:
The created scalar summary.
Raises:
ValueError: If the tag is already in use or the rank is not 0.
"""
tensor.get_shape().assert_has_rank(0)
tag = tag or tensor.op.name
_assert_summary_tag_unique(tag)
return standard_ops.scalar_summary(tag, tensor, name='%s_summary' % tag)
def _add_histogram_summary(tensor, tag=None):
"""Add a summary operation for the histogram of a tensor.
Args:
tensor: The tensor to summarize.
tag: The tag to use, if None then use tensor's op's name.
Returns:
The created histogram summary.
Raises:
ValueError: If the tag is already in use.
"""
# TODO(opensource): A global or scoped mechanism to disable summaries.
tag = tag or tensor.op.name
_assert_summary_tag_unique(tag)
return standard_ops.histogram_summary(tag, tensor, name='%s_summary' % tag)
def _apply_activation_with_summaries(x, activation_fn):
"""Returns activation_fn(x).
This applies the given activation and adds useful summaries specific to the
activation.
Args:
x: The tensor to apply activation to.
activation_fn: An activation function.
Returns:
A tensor with activation applied to x.
"""
if activation_fn is None:
return x
y = activation_fn(x)
if activation_fn in (nn.relu, nn.softplus, nn.relu6):
# Using x for comparison to avoid floating point equality and/or epsilons.
_add_scalar_summary(
standard_ops.reduce_mean(standard_ops.to_float(standard_ops.less(
x, 0.0))), '%s/zeros' % y.op.name)
if activation_fn is nn.relu6:
_add_scalar_summary(
standard_ops.reduce_mean(standard_ops.to_float(standard_ops.greater(
x, 6.0))), '%s/sixes' % y.op.name)
if activation_fn is nn.l2_normalize:
_add_scalar_summary(
standard_ops.reduce_mean(standard_ops.sqrt(standard_ops.sum(
standard_ops.square(x), 1))), '%s/length' % y.op.name)
_add_histogram_summary(y, '%s/activations' % y.op.name)
return y
def _apply_regularization(w, regularizer):
loss = regularizer(w)
if loss:
ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss)
def l1_regularizer(scale):
"""Returns a function that can be used to apply L1 regularization to weights.
L1 regularization encourages sparsity.
Args:
scale: A scalar multiplier `Tensor`. 0.0 disables the regularizer.
Returns:
A function with signature `l1(weights, name=None)` that applies L1
regularization.
Raises:
ValueError: If scale is outside of the range [0.0, 1.0] or if scale is not a
float.
"""
if isinstance(scale, numbers.Integral):
raise ValueError('scale cannot be an integer: %s' % scale)
if isinstance(scale, numbers.Real):
if scale < 0.:
raise ValueError('Setting a scale less than 0 on a regularizer: %g' %
scale)
if scale >= 1.:
raise ValueError('Setting a scale greater than 1 on a regularizer: %g' %
scale)
if scale == 0.:
logging.info('Scale of 0 disables regularizer.')
return lambda _, name=None: None
def l1(weights, name=None):
"""Applies L1 regularization to weights."""
with ops.op_scope([weights], name, 'l1_regularizer') as scope:
my_scale = ops.convert_to_tensor(scale,
dtype=weights.dtype.base_dtype,
name='scale')
return standard_ops.mul(
my_scale,
standard_ops.reduce_sum(standard_ops.abs(weights)),
name=scope)
return l1
def l2_regularizer(scale):
"""Returns a function that can be used to apply L2 regularization to weights.
Small values of L2 can help prevent overfitting the training data.
Args:
scale: A scalar multiplier `Tensor`. 0.0 disables the regularizer.
Returns:
A function with signature `l2(weights, name=None)` that applies L2
regularization.
Raises:
ValueError: If scale is outside of the range [0.0, 1.0] or if scale is not a
float.
"""
if isinstance(scale, numbers.Integral):
raise ValueError('scale cannot be an integer: %s' % (scale,))
if isinstance(scale, numbers.Real):
if scale < 0.:
raise ValueError('Setting a scale less than 0 on a regularizer: %g.' %
scale)
if scale >= 1.:
raise ValueError('Setting a scale greater than 1 on a regularizer: %g.' %
scale)
if scale == 0.:
logging.info('Scale of 0 disables regularizer.')
return lambda _, name=None: None
def l2(weights, name=None):
"""Applies l2 regularization to weights."""
with ops.op_scope([weights], name, 'l2_regularizer') as scope:
my_scale = ops.convert_to_tensor(scale,
dtype=weights.dtype.base_dtype,
name='scale')
return standard_ops.mul(my_scale, nn.l2_loss(weights), name=scope)
return l2
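Condensing the unit tests earlier in this commit, the two regularizers behave as follows (a sketch; the numbers are for illustration only):

```python
import tensorflow as tf

weights = tf.constant([1., -1., 4., 2.])
l1_loss = tf.learn.l1_regularizer(0.5)(weights)   # 0.5 * sum(|w|)      = 4.0
l2_loss = tf.learn.l2_regularizer(0.42)(weights)  # 0.42 * sum(w^2) / 2 = 4.62

with tf.Session() as sess:
    print(sess.run([l1_loss, l2_loss]))
```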
def fully_connected(x,
num_output_nodes,
activation_fn=None,
weight_init=None,
bias_init=standard_ops.constant_initializer(0.),
num_input_nodes=None,
name=None,
weight_collections=None,
bias_collections=None,
weight_regularizer=None,
create_summaries=True):
"""Adds the parameters for a fully connected layer and returns the output.
A fully connected layer is generally defined as a matrix multiply:
\\\\(y = f(w * x + b)\\\\) where **f** is given by `activation_fn`
This op creates `w` and optionally `b` (disable with `bias_init=None`) and
adds various summaries that can be useful for visualizing learning or
diagnosing training problems. The variable creation is compatible with
`tf.variable_scope` and so can be reused with `tf.variable_scope` or
`tf.make_template`.
In almost all cases, the number of input nodes can be inferred from the shape
of `x`, but if it is unspecified or additional size checks are desired, then
`num_input_nodes` can be specified.
Most of the details of variable creation can be controlled by specifying the
initializers (`weight_init` and `bias_init`) and which collections to place
the created variables in (`weight_collections` and `bias_collections`).
A per layer regularization can be specified by setting `weight_regularizer`.
This is only applied to weights and not the bias.
Args:
x: The input tensor.
num_output_nodes: The size of the output.
activation_fn: A function taking a single Tensor, applied as a
non-linearity. If None is used, then this is a linear layer.
weight_init: An optional initialization. If not specified, uses Xavier
initialization (see `tf.learn.xavier_initializer`).
bias_init: An initializer for the bias, defaults to 0.
num_input_nodes: The number of input nodes.
name: The name for this operation is used to name operations and to find
variables. If specified it must be unique for this scope, otherwise a
unique name starting with "fully_connected" will be created. See
`tf.variable_op_scope` for details.
weight_collections: List of graph collections for just weights.
bias_collections: List of graph collections for just bias.
weight_regularizer: A regularizer like the result of
`tf.learn.l1_regularizer` or `tf.learn.l2_regularizer`.
create_summaries: Set to false to disable summaries.
Returns:
The result of applying a fully connected layer.
Raises:
ValueError: if `x` is not rank 2; or `x`'s second dimension is not known
and `num_input_nodes` is not specified.
"""
with variable_scope.variable_op_scope([x], name, 'fully_connected') as vs:
# Check rank and if num_input_nodes is specified, make sure it matches.
x.get_shape().assert_is_compatible_with([None, num_input_nodes])
if not num_input_nodes:
if x.get_shape().dims is None or x.get_shape().dims[1].value is None:
raise ValueError(
'If x has an unknown second dimension then num_input_nodes '
'must be specified; shape: %s num_input_nodes: %s'
% (x.get_shape(), num_input_nodes))
else:
num_input_nodes = x.get_shape().dims[1].value
weight_init = weight_init or xavier_initializer(
num_input_nodes, num_output_nodes)
dtype = x.dtype
w = variable_scope.get_variable('weights',
shape=[num_input_nodes, num_output_nodes],
dtype=dtype,
initializer=weight_init,
collections=weight_collections)
if not vs.reuse and create_summaries:
_add_histogram_summary(w)
y = standard_ops.matmul(x, w)
# Regularization is only applied to the weights and not bias.
if weight_regularizer:
_apply_regularization(w, weight_regularizer)
if bias_init:
b = variable_scope.get_variable('bias',
shape=[num_output_nodes],
dtype=dtype,
initializer=bias_init,
collections=bias_collections)
if not vs.reuse and create_summaries:
_add_histogram_summary(b)
y = nn.bias_add(y, b)
if create_summaries:
return _apply_activation_with_summaries(y, activation_fn)
else:
return activation_fn(y)
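Putting the pieces together, a minimal usage sketch of the layer (the input shape and hyperparameters are illustrative only):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 6])
hidden = tf.learn.fully_connected(
    x, 8,
    activation_fn=tf.nn.relu,
    weight_regularizer=tf.learn.l2_regularizer(0.01))

# The weight penalty lands in the collection introduced above.
reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
```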

View File

@ -559,6 +559,7 @@ class OpDefLibrary(object):
"less than minimum %d." % "less than minimum %d." %
(key, op_type_name, len(value), (key, op_type_name, len(value),
attr_def.minimum)) attr_def.minimum))
attr_value.list.SetInParent()
if attr_def.type == "string": if attr_def.type == "string":
attr_value.s = _MakeStr(value, key) attr_value.s = _MakeStr(value, key)
if attr_def.HasField("allowed_values"): if attr_def.HasField("allowed_values"):