Fix bug where attrs with values that are the empty list
were not being properly set via the Python API. Change: 111635679
Parent: d38fecedf5
Commit: 02dff6d0d8
@@ -11,6 +11,14 @@
 safety is handled by `saturate_cast`, which makes sure over- and underflows
 are handled before casting to data types with smaller ranges.
 
+## Bug fixes
+
+* The Python API will now properly set the `list` member of `AttrValue` in
+  constructed `GraphDef` messages for empty lists. The serialization of some
+  graphs will change, but the change is both forwards and backwards compatible.
+  It will break tests that compare a generated `GraphDef` to a golden serialized
+  `GraphDef`.
+
 # Release 0.6.0
 
 ## Major Features and Improvements
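Note: a minimal sketch of the behavior the release note describes, at the protobuf level. It assumes the AttrValue proto is importable as `tensorflow.core.framework.attr_value_pb2`; the snippet itself is not part of this commit.

    from tensorflow.core.framework import attr_value_pb2

    attr_value = attr_value_pb2.AttrValue()
    # Before this fix, an empty Python list left the `list` field unset.
    # SetInParent() marks the field as present even though it has no elements,
    # which changes the serialized bytes of the enclosing GraphDef.
    attr_value.list.SetInParent()
    assert attr_value.HasField("list")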
@@ -121,6 +121,19 @@ class OpKernel {
   Status InputRange(const string& input_name, int* start, int* stop) const;
   Status OutputRange(const string& output_name, int* start, int* stop) const;
 
+  // TODO(irving): At the moment, the following three functions forward to
+  // TensorShapeUtils, but they are about to become the only versions once we
+  // become scalar strict.
+  bool allow_legacy_scalars() const { return kAllowLegacyScalars; }
+
+  bool IsLegacyScalar(const TensorShape& shape) const {
+    return TensorShapeUtils::IsLegacyScalar(shape);
+  }
+
+  bool IsLegacyVector(const TensorShape& shape) const {
+    return TensorShapeUtils::IsLegacyVector(shape);
+  }
+
  private:
   const NodeDef def_;
   const DataTypeVector input_types_;
@@ -455,6 +468,8 @@ class OpKernelContext {
 
   Env* env() const { return params_.device->env(); }
 
+  const OpKernel& op_kernel() const { return *params_.op_kernel; }
+
   // Input/output signature.
 
   int num_inputs() const { return params_.inputs->size(); }
@@ -45,7 +45,7 @@ class ConcatOp : public OpKernel {
     const Tensor* concat_dim_tensor;
     OP_REQUIRES_OK(c, c->input("concat_dim", &concat_dim_tensor));
     OP_REQUIRES(
-        c, TensorShapeUtils::IsLegacyScalar(concat_dim_tensor->shape()),
+        c, IsLegacyScalar(concat_dim_tensor->shape()),
         errors::InvalidArgument(
             "Concat dim tensor should be a scalar integer, but got shape ",
             concat_dim_tensor->shape().DebugString()));
@@ -57,7 +57,7 @@ class ConcatOp : public OpKernel {
     const TensorShape& input_shape = values[0].shape();
     OP_REQUIRES(
         c, (0 <= concat_dim && concat_dim < input_dims) ||
-               (kAllowLegacyScalars && concat_dim == 0),
+               (allow_legacy_scalars() && concat_dim == 0),
         errors::InvalidArgument(
             "ConcatOp : Expected concatenating dimensions in the range [", 0,
             ", ", input_dims, "), but got ", concat_dim));
@@ -74,10 +74,10 @@ class ConcatOp : public OpKernel {
       inputs_flat_dim0 *= input_shape.dim_size(d);
     }
     int output_concat_dim = 0;
-    const bool input_is_scalar = TensorShapeUtils::IsLegacyScalar(input_shape);
+    const bool input_is_scalar = IsLegacyScalar(input_shape);
     for (int i = 0; i < N; ++i) {
       const auto in = values[i];
-      const bool in_is_scalar = TensorShapeUtils::IsLegacyScalar(in.shape());
+      const bool in_is_scalar = IsLegacyScalar(in.shape());
       OP_REQUIRES(
           c, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
           errors::InvalidArgument(
@@ -100,12 +100,12 @@ class ConcatOp : public OpKernel {
         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
             in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
       }
-      // TODO(irving): Remove check once !kAllowLegacyScalars
+      // TODO(irving): Remove check once !allow_legacy_scalars().
       output_concat_dim += in.dims() > 0 ? in.dim_size(concat_dim) : 1;
     }
 
     TensorShape output_shape(input_shape);
-    // TODO(irving): Remove rank 0 case once !kAllowLegacyScalars
+    // TODO(irving): Remove rank 0 case once !allow_legacy_scalars().
     if (output_shape.dims() == 0) {
       output_shape.AddDim(output_concat_dim);
     } else {
@@ -143,11 +143,14 @@ class FillOp : public OpKernel {
 
   void Compute(OpKernelContext* context) override {
     const Tensor& Tdims = context->input(0);
-    OP_REQUIRES(context, TensorShapeUtils::IsLegacyVector(Tdims.shape()),
-                errors::InvalidArgument("dims must be a vector of int32."));
+    OP_REQUIRES(
+        context, IsLegacyVector(Tdims.shape()),
+        errors::InvalidArgument("dims must be a vector of int32, got shape ",
+                                Tdims.shape().ShortDebugString()));
     const Tensor& Tvalue = context->input(1);
-    OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(Tvalue.shape()),
-                errors::InvalidArgument("value must be a scalar."));
+    OP_REQUIRES(context, IsLegacyScalar(Tvalue.shape()),
+                errors::InvalidArgument("value must be a scalar, got shape ",
+                                        Tvalue.shape().ShortDebugString()));
     auto dims = Tdims.flat<int32>();
     for (int i = 0; i < dims.size(); i++) {
       OP_REQUIRES(context, dims(i) >= 0,
@@ -28,7 +28,7 @@ class AssertOp : public OpKernel {
 
   void Compute(OpKernelContext* ctx) override {
     const Tensor& cond = ctx->input(0);
-    OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(cond.shape()),
+    OP_REQUIRES(ctx, IsLegacyScalar(cond.shape()),
                 errors::InvalidArgument("In[0] should be a scalar: ",
                                         cond.shape().ShortDebugString()));
 
@@ -59,7 +59,8 @@ class PadOp : public OpKernel {
         errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
                                 in1.shape().DebugString()));
     const int fixed_dims =
-        (kAllowLegacyScalars && dims == 0 && in1.dim_size(0) == 1) ? 1 : dims;
+        (allow_legacy_scalars() && dims == 0 && in1.dim_size(0) == 1) ? 1
+                                                                      : dims;
     OP_REQUIRES(
         context, fixed_dims == in1.dim_size(0),
         errors::InvalidArgument(
@@ -76,7 +77,7 @@ class PadOp : public OpKernel {
                   errors::InvalidArgument("Paddings must be non-negative: ",
                                           before_d, " ", after_d));
       const int size_d =
-          (kAllowLegacyScalars && d == in0.dims()) ? 1 : in0.dim_size(d);
+          (allow_legacy_scalars() && d == in0.dims()) ? 1 : in0.dim_size(d);
       output_shape.AddDim(before_d + size_d + after_d);
     }
     Tensor* output = nullptr;
@@ -89,7 +90,7 @@ class PadOp : public OpKernel {
         break;
       case 1:
         // TODO(irving): Once Pad doesn't need a scalar special case,
-        // change flat to tensor.  That is, once !kAllowLegacyScalars.
+        // change flat to tensor.  That is, once !allow_legacy_scalars().
         Operate<1>(context, in0.flat<T>(), paddings, output);
         break;
       case 2:
@@ -180,7 +180,7 @@ namespace {
 
 static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape,
                                       int index, Tensor** output) {
-  if (!TensorShapeUtils::IsLegacyVector(shape.shape())) {
+  if (!ctx->op_kernel().IsLegacyVector(shape.shape())) {
     return errors::InvalidArgument(
         "shape must be a vector of {int32,int64}, got shape ",
         shape.shape().ShortDebugString());
@@ -35,7 +35,7 @@ class ReshapeOp : public OpKernel {
     const Tensor& input = context->input(0);
     const Tensor& sizes = context->input(1);
     // Preliminary validation of sizes.
-    OP_REQUIRES(context, TensorShapeUtils::IsLegacyVector(sizes.shape()),
+    OP_REQUIRES(context, IsLegacyVector(sizes.shape()),
                 errors::InvalidArgument("sizes input must be 1-D, not shape ",
                                         sizes.shape().ShortDebugString()));
     const int64 num_dims = sizes.NumElements();
@@ -55,7 +55,7 @@ class ShardedFilenameOp : public OpKernel {
   void Compute(OpKernelContext* ctx) override {
     static const char* input_names[3] = {"basename", "shard", "num_shards"};
     for (int i = 0; i < ctx->num_inputs(); ++i) {
-      OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(ctx->input(i).shape()),
+      OP_REQUIRES(ctx, IsLegacyScalar(ctx->input(i).shape()),
                   errors::InvalidArgument(
                       input_names[i], " must be a scalar, got shape ",
                       ctx->input(i).shape().ShortDebugString()));
@@ -78,7 +78,7 @@ class ShardedFilespecOp : public OpKernel {
   void Compute(OpKernelContext* ctx) override {
     static const char* input_names[2] = {"basename", "num_shards"};
     for (int i = 0; i < ctx->num_inputs(); ++i) {
-      OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(ctx->input(i).shape()),
+      OP_REQUIRES(ctx, IsLegacyScalar(ctx->input(i).shape()),
                   errors::InvalidArgument(
                       input_names[i], " must be a scalar, got shape ",
                       ctx->input(i).shape().ShortDebugString()));
@@ -184,7 +184,7 @@ class UnsortedSegmentSumOp : public OpKernel {
     const Tensor& num_segments = context->input(2);
 
     OP_REQUIRES(
-        context, TensorShapeUtils::IsLegacyScalar(num_segments.shape()),
+        context, IsLegacyScalar(num_segments.shape()),
         errors::InvalidArgument("num_segments should be a scalar, not shape ",
                                 num_segments.shape().ShortDebugString()));
 
@@ -406,7 +406,7 @@ class SparseSegmentMeanGradOp : public OpKernel {
                 errors::InvalidArgument("indices should be a vector."));
     OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
                 errors::InvalidArgument("segment_ids should be a vector."));
-    OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(output_dim0.shape()),
+    OP_REQUIRES(context, IsLegacyScalar(output_dim0.shape()),
                 errors::InvalidArgument("output_dim0 should be a scalar."));
 
     const int64 N = indices.NumElements();
@@ -34,13 +34,13 @@ class RangeOp : public OpKernel {
     const Tensor& start_in = context->input(0);
     const Tensor& limit_in = context->input(1);
     const Tensor& delta_in = context->input(2);
-    OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(start_in.shape()),
+    OP_REQUIRES(context, IsLegacyScalar(start_in.shape()),
                 errors::InvalidArgument("start must be a scalar, not shape ",
                                         start_in.shape().ShortDebugString()));
-    OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(limit_in.shape()),
+    OP_REQUIRES(context, IsLegacyScalar(limit_in.shape()),
                 errors::InvalidArgument("limit must be a scalar, not shape ",
                                         limit_in.shape().ShortDebugString()));
-    OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(delta_in.shape()),
+    OP_REQUIRES(context, IsLegacyScalar(delta_in.shape()),
                 errors::InvalidArgument("delta must be a scalar, not shape ",
                                         delta_in.shape().ShortDebugString()));
     const int32 start = GetValue(start_in.scalar<T>()());
@@ -69,14 +69,15 @@ static void SharedValidation(OpKernelContext* context,
   const Tensor& size_tensor = context->input(2);
 
   OP_REQUIRES(
-      context, TensorShapeUtils::IsLegacyVector(begin_tensor.shape()) &&
-                   TensorShapeUtils::IsLegacyVector(size_tensor.shape()) &&
+      context, context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
+                   context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
                    begin_tensor.NumElements() == input.dims() &&
                    size_tensor.NumElements() == input.dims(),
       errors::InvalidArgument(
           "Expected begin and size arguments to be 1-D tensors of size ",
-          input.dims(), ", but got ", begin_tensor.NumElements(), " and ",
-          size_tensor.NumElements(), " instead."));
+          input.dims(), ", but got shapes ",
+          begin_tensor.shape().ShortDebugString(), " and ",
+          size_tensor.shape().ShortDebugString(), " instead."));
 
   const int input_dims = input.dims();
   *begin = IntTensorToInt64Vec(begin_tensor);
@@ -60,7 +60,7 @@ class SparseToDense : public OpKernel {
     // output_shape
     const Tensor& output_shape = c->input(1);
     OP_REQUIRES(
-        c, TensorShapeUtils::IsLegacyVector(output_shape.shape()),
+        c, IsLegacyVector(output_shape.shape()),
         errors::InvalidArgument("output_shape should be a vector, got shape ",
                                 output_shape.shape().ShortDebugString()));
     OP_REQUIRES(c, output_shape.NumElements() == num_dims,
@@ -48,8 +48,8 @@ class SummaryImageOp : public OpKernel {
   void Compute(OpKernelContext* c) override {
     const Tensor& tags = c->input(0);
     const Tensor& tensor = c->input(1);
-    OP_REQUIRES(c, TensorShapeUtils::IsLegacyScalar(tags.shape()),
-                errors::InvalidArgument("Tags must have be a scalar"));
+    OP_REQUIRES(c, IsLegacyScalar(tags.shape()),
+                errors::InvalidArgument("Tags must be a scalar"));
     OP_REQUIRES(c, tensor.dims() == 4 &&
                        (tensor.dim_size(3) == 1 || tensor.dim_size(3) == 3 ||
                         tensor.dim_size(3) == 4),
@@ -40,12 +40,12 @@ class SummaryScalarOp : public OpKernel {
     const Tensor& tags = c->input(0);
     const Tensor& values = c->input(1);
 
-    OP_REQUIRES(c, tags.IsSameSize(values) ||
-                       (TensorShapeUtils::IsLegacyScalar(tags.shape()) &&
-                        TensorShapeUtils::IsLegacyScalar(values.shape())),
-                errors::InvalidArgument("tags and values not the same shape: ",
-                                        tags.shape().ShortDebugString(), " != ",
-                                        values.shape().ShortDebugString()));
+    OP_REQUIRES(c, tags.IsSameSize(values) || (IsLegacyScalar(tags.shape()) &&
+                                               IsLegacyScalar(values.shape())),
+                errors::InvalidArgument("tags and values not the same shape: ",
+                                        tags.shape().ShortDebugString(), " != ",
+                                        values.shape().ShortDebugString(),
+                                        SingleTag(tags)));
     auto Ttags = tags.flat<string>();
     auto Tvalues = values.flat<T>();
     Summary s;
@@ -59,6 +59,15 @@ class SummaryScalarOp : public OpKernel {
     OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
     CHECK(s.SerializeToString(&summary_tensor->scalar<string>()()));
   }
 
+  // If there's only one tag, include it in the error message
+  static string SingleTag(const Tensor& tags) {
+    if (tags.NumElements() == 1) {
+      return strings::StrCat(" (tag '", tags.flat<string>()(0), "')");
+    } else {
+      return "";
+    }
+  }
+
 };
 
 template <typename T>
|
|||||||
const Tensor& tags = c->input(0);
|
const Tensor& tags = c->input(0);
|
||||||
const Tensor& values = c->input(1);
|
const Tensor& values = c->input(1);
|
||||||
const auto flat = values.flat<T>();
|
const auto flat = values.flat<T>();
|
||||||
OP_REQUIRES(c, TensorShapeUtils::IsLegacyScalar(tags.shape()),
|
OP_REQUIRES(c, IsLegacyScalar(tags.shape()),
|
||||||
errors::InvalidArgument("tags must be scalar"));
|
errors::InvalidArgument("tags must be scalar"));
|
||||||
// Build histogram of values in "values" tensor
|
// Build histogram of values in "values" tensor
|
||||||
histogram::Histogram histo;
|
histogram::Histogram histo;
|
||||||
|
@@ -46,7 +46,7 @@ class TileOp : public OpKernel {
     const Tensor& multiples = context->input(1);
 
     OP_REQUIRES(
-        context, TensorShapeUtils::IsLegacyVector(multiples.shape()),
+        context, IsLegacyVector(multiples.shape()),
         errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
                                 multiples.shape().ShortDebugString()));
     OP_REQUIRES(context, input.dims() == multiples.NumElements(),
@@ -192,7 +192,7 @@ class TileGradientOp : public OpKernel {
     const Tensor& input = context->input(0);
     const Tensor& multiples = context->input(1);
     OP_REQUIRES(
-        context, TensorShapeUtils::IsLegacyVector(multiples.shape()),
+        context, IsLegacyVector(multiples.shape()),
         errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
                                 multiples.shape().ShortDebugString()));
     OP_REQUIRES(context, input.dims() == multiples.NumElements(),
@@ -153,7 +153,7 @@ class ApplyGradientDescentOp : public OpKernel {
                 errors::FailedPrecondition(
                     "Attempting to use uninitialized variables: ", def().input(0)));
     const Tensor& alpha = ctx->input(1);
-    OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(alpha.shape()),
+    OP_REQUIRES(ctx, IsLegacyScalar(alpha.shape()),
                 errors::InvalidArgument("alpha is not a scalar: ",
                                         alpha.shape().DebugString()));
     const Tensor& delta = ctx->input(2);
@@ -242,7 +242,7 @@ class ApplyAdagradOp : public OpKernel {
                 errors::FailedPrecondition(
                     "Attempting to use uninitialized variables: ", def().input(1)));
     const Tensor& lr = ctx->input(2);
-    OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(lr.shape()),
+    OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
                 errors::InvalidArgument("lr is not a scalar: ",
                                         lr.shape().DebugString()));
     const Tensor& grad = ctx->input(3);
@@ -336,7 +336,7 @@ class SparseApplyAdagradOp : public OpKernel {
                 errors::InvalidArgument("var must be at least 1 dimensional"));
 
     const Tensor& lr = ctx->input(2);
-    OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(lr.shape()),
+    OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
                 errors::InvalidArgument("lr is not a scalar: ",
                                         lr.shape().DebugString()));
     const Tensor& grad = ctx->input(3);
@@ -683,6 +683,7 @@ py_library(
         "ops/image_ops.py",
         "ops/init_ops.py",
         "ops/io_ops.py",
+        "ops/learn.py",
        "ops/linalg_grad.py",
        "ops/linalg_ops.py",
        "ops/logging_ops.py",
@@ -57,7 +57,8 @@ from tensorflow.python.client.client_lib import *
 # Ops
 from tensorflow.python.ops.standard_ops import *
 
-# Bring nn, image_ops, user_ops, compat as a subpackages
+# Bring learn, nn, image_ops, user_ops, compat as a subpackages
+from tensorflow.python.ops import learn
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import image_ops as image
 from tensorflow.python.user_ops import user_ops
@@ -77,7 +78,7 @@ from tensorflow.python.platform import resource_loader
 from tensorflow.python.platform import test
 
 # Don't export modules except for the few we really want
-_whitelist = set([app, compat, errors, flags, image, logging, nn,
+_whitelist = set([app, compat, errors, flags, image, learn, logging, nn,
                   python_io, resource_loader, test, train, user_ops])
 # TODO(b/25561952): tf.tensor_util is DEPRECATED. Please avoid.
 _whitelist.update([tensor_util])  # pylint: disable=undefined-variable
@@ -3159,6 +3159,8 @@ class GraphKeys(object):
     keep moving averages.  See
     [`tf.moving_average_variables()`](../../api_docs/python/state_ops.md#moving_average_variables)
     for more details.
+  * `REGULARIZATION_LOSSES`: regularization losses collected during graph
+    construction.
   """
 
   # Key to collect Variable objects that must be saved and restored
@@ -3178,6 +3180,8 @@ class GraphKeys(object):
   ASSET_FILEPATHS = "asset_filepaths"
   # Key to collect Variable objects that keep moving averages.
   MOVING_AVERAGE_VARIABLES = "moving_average_variables"
+  # Key to collected regularization losses at graph construction.
+  REGULARIZATION_LOSSES = "regularization_losses"
 
 
 def add_to_collection(name, value):
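Note: a hedged sketch of how the new REGULARIZATION_LOSSES collection is meant to be consumed when building a total training loss. It uses the tf.learn helpers added later in this commit; the shapes and the 0.1 scale are illustrative only.

    import tensorflow as tf

    x = tf.constant([[1., 2., 3.], [4., 5., 6.]])
    y = tf.learn.fully_connected(
        x, 4, weight_regularizer=tf.learn.l2_regularizer(0.1))
    data_loss = tf.reduce_mean(tf.square(y))
    # fully_connected added its weight penalty to the collection; fold it in.
    reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    total_loss = data_loss
    if reg_losses:
      total_loss += tf.add_n(reg_losses)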
New file: tensorflow/python/kernel_tests/learn_test.py (225 lines)
@@ -0,0 +1,225 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for tf.learn."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

import tensorflow.python.platform  # pylint: disable=unused-import,g-bad-import-order

import numpy as np
import six
import tensorflow as tf

from tensorflow.python.framework import tensor_util


class FullyConnectedTest(tf.test.TestCase):

  def setUp(self):
    tf.test.TestCase.setUp(self)
    tf.set_random_seed(1234)
    self.input = tf.constant([[1., 2., 3.], [-4., 5., -6.]])
    assert not tf.get_collection(tf.GraphKeys.SUMMARIES)

  def assert_summary_scope(self, regexp):
    for summary in tf.get_collection(tf.GraphKeys.SUMMARIES):
      tag = tensor_util.ConstantValue(summary.op.inputs[0])
      assert tag is not None, 'All summaries have constant tags'
      tag = str(tag)
      assert isinstance(tag[0], six.string_types), tag[0]
      assert re.match(regexp, tag), "tag doesn't match %s: %s" % (regexp, tag)

  def test_basic_use(self):
    output = tf.learn.fully_connected(self.input, 8, activation_fn=tf.nn.relu)

    with tf.Session() as sess:
      with self.assertRaises(tf.errors.FailedPreconditionError):
        sess.run(output)

      tf.initialize_all_variables().run()
      out_value = sess.run(output)

    self.assertEqual(output.get_shape().as_list(), [2, 8])
    self.assertTrue(np.all(out_value >= 0),
                    'Relu should have capped all values.')

    self.assertGreater(tf.get_collection(tf.GraphKeys.SUMMARIES), 0,
                       'Some summaries should have been added.')
    self.assertEqual(2,
                     len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))
    self.assertEqual(0,
                     len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)))
    self.assert_summary_scope('fully_connected')

  def test_variable_reuse_with_scope(self):
    with tf.variable_scope('test') as vs:
      output1 = tf.learn.fully_connected(self.input,
                                         8,
                                         activation_fn=tf.nn.relu)
      output2 = tf.learn.fully_connected(self.input,
                                         8,
                                         activation_fn=tf.nn.relu)

    with tf.variable_scope(vs, reuse=True):
      output3 = tf.learn.fully_connected(self.input,
                                         8,
                                         activation_fn=tf.nn.relu)

    with tf.Session() as sess:
      tf.initialize_all_variables().run()
      out_value1, out_value2, out_value3 = sess.run([output1, output2, output3])

    self.assertFalse(np.allclose(out_value1, out_value2))
    self.assertAllClose(out_value1, out_value3)

  def test_variable_reuse_with_template(self):
    tmpl1 = tf.make_template('test',
                             tf.learn.fully_connected,
                             num_output_nodes=8)
    output1 = tmpl1(self.input)
    output2 = tmpl1(self.input)

    with tf.Session() as sess:
      tf.initialize_all_variables().run()
      out_value1, out_value2 = sess.run([output1, output2])
    self.assertAllClose(out_value1, out_value2)
    self.assert_summary_scope(r'test(_\d)?/fully_connected')

  def test_custom_initializers(self):
    output = tf.learn.fully_connected(self.input,
                                      2,
                                      activation_fn=tf.nn.relu,
                                      weight_init=tf.constant_initializer(2.0),
                                      bias_init=tf.constant_initializer(1.0))

    with tf.Session() as sess:
      tf.initialize_all_variables().run()
      out_value = sess.run(output)

    self.assertAllClose(np.array([[13.0, 13.0], [0.0, 0.0]]), out_value)

  def test_custom_collections(self):
    tf.learn.fully_connected(self.input,
                             2,
                             activation_fn=tf.nn.relu,
                             weight_collections=['unbiased'],
                             bias_collections=['biased'])

    self.assertEquals(1, len(tf.get_collection('unbiased')))
    self.assertEquals(1, len(tf.get_collection('biased')))

  def test_all_custom_collections(self):
    tf.learn.fully_connected(self.input,
                             2,
                             activation_fn=tf.nn.relu,
                             weight_collections=['unbiased', 'all'],
                             bias_collections=['biased', 'all'])

    self.assertEquals(1, len(tf.get_collection('unbiased')))
    self.assertEquals(1, len(tf.get_collection('biased')))
    self.assertEquals(2, len(tf.get_collection('all')))
    self.assertEquals(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
                      tf.get_collection('all'))

  def test_no_summaries(self):
    tf.learn.fully_connected(self.input,
                             2,
                             activation_fn=tf.nn.relu,
                             create_summaries=False)
    self.assertEquals([], tf.get_collection(tf.GraphKeys.SUMMARIES))

  def test_regularizer(self):
    cnt = [0]
    tensor = tf.constant(5.0)
    def test_fn(_):
      cnt[0] += 1
      return tensor

    tf.learn.fully_connected(self.input, 2, weight_regularizer=test_fn)

    self.assertEqual([tensor],
                     tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    self.assertEqual(1, cnt[0])

  def test_shape_enforcement(self):
    place = tf.placeholder(tf.float32)
    with self.assertRaises(ValueError):
      tf.learn.fully_connected(place, 8)
    tf.learn.fully_connected(place, 8, num_input_nodes=5)  # No error

    place.set_shape([None, None])
    with self.assertRaises(ValueError):
      tf.learn.fully_connected(place, 8)
    tf.learn.fully_connected(place, 8, num_input_nodes=5)  # No error

    place.set_shape([None, 6])
    tf.learn.fully_connected(place, 8)  # No error
    with self.assertRaises(ValueError):
      tf.learn.fully_connected(place, 8, num_input_nodes=5)

    place = tf.placeholder(tf.float32)
    place.set_shape([2, 6, 5])
    with self.assertRaises(ValueError):
      tf.learn.fully_connected(place, 8)

  def test_no_bias(self):
    tf.learn.fully_connected(self.input, 2, bias_init=None)

    self.assertEqual(1,
                     len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))


class RegularizerTest(tf.test.TestCase):

  def test_l1(self):
    with self.assertRaises(ValueError):
      tf.learn.l1_regularizer(2.)
    with self.assertRaises(ValueError):
      tf.learn.l1_regularizer(-1.)
    with self.assertRaises(ValueError):
      tf.learn.l1_regularizer(0)

    self.assertIsNone(tf.learn.l1_regularizer(0.)(None))

    values = np.array([1., -1., 4., 2.])
    weights = tf.constant(values)
    with tf.Session() as sess:
      result = sess.run(tf.learn.l1_regularizer(.5)(weights))

    self.assertAllClose(np.abs(values).sum() * .5, result)

  def test_l2(self):
    with self.assertRaises(ValueError):
      tf.learn.l2_regularizer(2.)
    with self.assertRaises(ValueError):
      tf.learn.l2_regularizer(-1.)
    with self.assertRaises(ValueError):
      tf.learn.l2_regularizer(0)

    self.assertIsNone(tf.learn.l2_regularizer(0.)(None))

    values = np.array([1., -1., 4., 2.])
    weights = tf.constant(values)
    with tf.Session() as sess:
      result = sess.run(tf.learn.l2_regularizer(.42)(weights))

    self.assertAllClose(np.power(values, 2).sum() / 2.0 * .42, result)

if __name__ == '__main__':
  tf.test.main()
New file: tensorflow/python/ops/learn.py (359 lines)
@@ -0,0 +1,359 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=g-short-docstring-punctuation
"""## Higher level ops related to regularization and building layers.

This package provides several ops that take care of creating variables that are
used internally in a consistent way and provide the building blocks for many
common machine learning algorithms.

@@fully_connected

## Regularizers

Regularization can help prevent overfitting.
These have the signature `fn(weights)`. The loss is typically added to
`tf.GraphKeys.REGULARIZATION_LOSS`

@@l1_regularizer
@@l2_regularizer

## Initializations

This also includes a common initialization for connecting multiple layers.

@@xavier_initializer
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numbers

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import nn
from tensorflow.python.ops import standard_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import logging


__all__ = ['xavier_initializer', 'fully_connected', 'l1_regularizer',
           'l2_regularizer']


def xavier_initializer(n_inputs, n_outputs, uniform=True):
  """Set the parameter initialization using the method described in paper.

  Xavier Glorot and Yoshua Bengio (2010):
      Understanding the difficulty of training deep feedforward neural
      networks. International conference on artificial intelligence and
      statistics.

  This method is designed to keep the scale of the gradients roughly the same
  in all layers. In uniform distribution this ends up being the range:
  `x = sqrt(6. / (in + out)); [-x, x]` and for normal distribution a standard
  deviation of `sqrt(3. / (in + out))` is used.

  Args:
    n_inputs: The number of input nodes into each output.
    n_outputs: The number of output nodes for each input.
    uniform: If true use a uniform distribution, otherwise use a truncated
      normal.

  Returns:
    An initializer.
  """
  if uniform:
    # 6 was used in the paper.
    init_range = math.sqrt(6.0 / (n_inputs + n_outputs))
    return standard_ops.random_uniform_initializer(-init_range, init_range)
  else:
    # 3 gives us approximately the same limits as above since this repicks
    # values greater than 2 standard deviations from the mean.
    stddev = math.sqrt(3.0 / (n_inputs + n_outputs))
    return standard_ops.truncated_normal_initializer(stddev=stddev)


def _assert_summary_tag_unique(tag):
  for summary in ops.get_collection(ops.GraphKeys.SUMMARIES):
    old_tag = tensor_util.ConstantValue(summary.op.inputs[0])
    if tag == str(old_tag):
      raise ValueError('Conflict with summary tag: %s exists on summary %s %s' %
                       (tag, summary, old_tag))


def _add_scalar_summary(tensor, tag=None):
  """Add a summary operation for the tensor.

  Args:
    tensor: The tensor to summarize.
    tag: The tag to use, if None then use tensor's op's name.

  Returns:
    The created histogram summary.

  Raises:
    ValueError: If the tag is already in use or the rank is not 0.
  """
  tensor.get_shape().assert_has_rank(0)
  tag = tag or tensor.op.name
  _assert_summary_tag_unique(tag)
  return standard_ops.scalar_summary(tag, tensor, name='%s_summary' % tag)


def _add_histogram_summary(tensor, tag=None):
  """Add a summary operation for the histogram of a tensor.

  Args:
    tensor: The tensor to summarize.
    tag: The tag to use, if None then use tensor's op's name.

  Returns:
    The created histogram summary.

  Raises:
    ValueError: If the tag is already in use.
  """
  # TODO(opensource): A global or scoped mechanism to disable summaries.
  tag = tag or tensor.op.name
  _assert_summary_tag_unique(tag)
  return standard_ops.histogram_summary(tag, tensor, name='%s_summary' % tag)


def _apply_activation_with_summaries(x, activation_fn):
  """Returns activation_fn(x).

  This applies the given activation and adds useful summaries specific to the
  activation.

  Args:
    x: The tensor to apply activation to.
    activation_fn: An activation function.
  Returns:
    A tensor with activation applied to x.
  """
  if activation_fn is None:
    return x
  y = activation_fn(x)
  if activation_fn in (nn.relu, nn.softplus, nn.relu6):
    # Using x for comparison to avoid floating point equality and/or epsilons.
    _add_scalar_summary(
        standard_ops.reduce_mean(standard_ops.to_float(standard_ops.less(
            x, 0.0))), '%s/zeros' % y.op.name)
  if activation_fn is nn.relu6:
    _add_scalar_summary(
        standard_ops.reduce_mean(standard_ops.to_float(standard_ops.greater(
            x, 6.0))), '%s/sixes' % y.op.name)
  if activation_fn is nn.l2_normalize:
    _add_scalar_summary(
        standard_ops.reduce_mean(standard_ops.sqrt(standard_ops.sum(
            standard_ops.square(x), 1))), '%s/length' % y.op.name)
  _add_histogram_summary(y, '%s/activations' % y.op.name)
  return y


def _apply_regularization(w, regularizer):
  loss = regularizer(w)
  if loss:
    ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss)


def l1_regularizer(scale):
  """Returns a function that can be used to apply L1 regularization to weights.

  L1 regularization encourages sparsity.

  Args:
    scale: A scalar multiplier `Tensor`. 0.0 disables the regularizer.

  Returns:
    A function with signature `l1(weights, name=None)` that apply L1
    regularization.

  Raises:
    ValueError: If scale is outside of the range [0.0, 1.0] or if scale is not a
      float.
  """
  if isinstance(scale, numbers.Integral):
    raise ValueError('scale cannot be an integer: %s' % scale)
  if isinstance(scale, numbers.Real):
    if scale < 0.:
      raise ValueError('Setting a scale less than 0 on a regularizer: %g' %
                       scale)
    if scale >= 1.:
      raise ValueError('Setting a scale greater than 1 on a regularizer: %g' %
                       scale)
    if scale == 0.:
      logging.info('Scale of 0 disables regularizer.')
      return lambda _, name=None: None
  def l1(weights, name=None):
    """Applies L1 regularization to weights."""
    with ops.op_scope([weights], name, 'l1_regularizer') as scope:
      my_scale = ops.convert_to_tensor(scale,
                                       dtype=weights.dtype.base_dtype,
                                       name='scale')
      return standard_ops.mul(
          my_scale,
          standard_ops.reduce_sum(standard_ops.abs(weights)),
          name=scope)
  return l1


def l2_regularizer(scale):
  """Returns a function that can be used to apply L2 regularization to weights.

  Small values of L2 can help prevent overfitting the training data.

  Args:
    scale: A scalar multiplier `Tensor`. 0.0 disables the regularizer.

  Returns:
    A function with signature `l2(weights, name=None)` that applies L2
    regularization.

  Raises:
    ValueError: If scale is outside of the range [0.0, 1.0] or if scale is not a
      float.
  """
  if isinstance(scale, numbers.Integral):
    raise ValueError('scale cannot be an integer: %s' % (scale,))
  if isinstance(scale, numbers.Real):
    if scale < 0.:
      raise ValueError('Setting a scale less than 0 on a regularizer: %g.' %
                       scale)
    if scale >= 1.:
      raise ValueError('Setting a scale greater than 1 on a regularizer: %g.' %
                       scale)
    if scale == 0.:
      logging.info('Scale of 0 disables regularizer.')
      return lambda _, name=None: None
  def l2(weights, name=None):
    """Applies l2 regularization to weights."""
    with ops.op_scope([weights], name, 'l2_regularizer') as scope:
      my_scale = ops.convert_to_tensor(scale,
                                       dtype=weights.dtype.base_dtype,
                                       name='scale')
      return standard_ops.mul(my_scale, nn.l2_loss(weights), name=scope)
  return l2


def fully_connected(x,
                    num_output_nodes,
                    activation_fn=None,
                    weight_init=None,
                    bias_init=standard_ops.constant_initializer(0.),
                    num_input_nodes=None,
                    name=None,
                    weight_collections=None,
                    bias_collections=None,
                    weight_regularizer=None,
                    create_summaries=True):
  """Adds the parameters for a fully connected layer and returns the output.

  A fully connected layer is generally defined as a matrix multiply:
  \\\\(y = f(w * x + b)\\\\) where **f** is given by `activation_fn`

  This op creates `w` and optionally `b` (disable with `bias_init=None`) and
  adds various summaries that can be useful for visualizing learning or
  diagnosing training problems. The variable creation is compatible with
  `tf.variable_scope` and so can be reused with `tf.variable_scope` or
  `tf.make_template`.

  In almost all cases, the number of input nodes can be inferred from the shape
  of `x`, but if it is unspecified or additional size checks are desired, then
  `num_input_nodes` can be specified.

  Most of the details of variable creation can be controlled by specifying the
  initializers (`weight_init` and `bias_init`) and which collections to place
  the created variables in (`weight_collections` and `bias_collections`).

  A per layer regularization can be specified by setting `weight_regularizer`.
  This is only applied to weights and not the bias.

  Args:
    x: The input tensor.
    num_output_nodes: The size of the output.
    activation_fn: A function that requires a single Tensor that is applied as a
      non-linearity. If None is used, then this is a linear layer.
    weight_init: An optional initialization. If not specified, uses Xavier
      initialization (see `tf.learn.xavier_initializer`).
    bias_init: An initializer for the bias, defaults to 0.
    num_input_nodes: The number of input nodes.
    name: The name for this operation is used to name operations and to find
      variables. If specified it must be unique for this scope, otherwise a
      unique name starting with "fully_connected" will be created. See
      `tf.variable_op_scope` for details.
    weight_collections: List of graph collections for just weights.
    bias_collections: List of graph collections for just bias.
    weight_regularizer: A regularizer like the result of
      `tf.learn.l1_regularizer` or `tf.learn.l2_regularizer`.
    create_summaries: Set to false to disable summaries.

  Returns:
    The result of applying a fully connected layer.

  Raises:
    ValueError: if `x` is not rank 2; or `x`'s second dimension is not known
      and `num_input_nodes` is not specified.
  """
  with variable_scope.variable_op_scope([x], name, 'fully_connected') as vs:
    # Check rank and if num_input_nodes is specified, make sure it matches.
    x.get_shape().assert_is_compatible_with([None, num_input_nodes])

    if not num_input_nodes:
      if x.get_shape().dims is None or x.get_shape().dims[1].value is None:
        raise ValueError(
            'If x has an unknown first dimension then num_input_nodes '
            'must be specified; shape: %s num_input_nodes: %s'
            % (x.get_shape(), num_input_nodes))
      else:
        num_input_nodes = x.get_shape().dims[1].value

    weight_init = weight_init or xavier_initializer(
        num_input_nodes, num_output_nodes)

    dtype = x.dtype
    w = variable_scope.get_variable('weights',
                                    shape=[num_input_nodes, num_output_nodes],
                                    dtype=dtype,
                                    initializer=weight_init,
                                    collections=weight_collections)

    if not vs.reuse and create_summaries:
      _add_histogram_summary(w)

    y = standard_ops.matmul(x, w)
    # Regularization is only applied to the weights and not bias.
    if weight_regularizer:
      _apply_regularization(w, weight_regularizer)
    if bias_init:
      b = variable_scope.get_variable('bias',
                                      shape=[num_output_nodes],
                                      dtype=dtype,
                                      initializer=bias_init,
                                      collections=bias_collections)
      if not vs.reuse and create_summaries:
        _add_histogram_summary(b)

      y = nn.bias_add(y, b)

    if create_summaries:
      return _apply_activation_with_summaries(y, activation_fn)
    else:
      return activation_fn(y)
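Note: a short usage sketch of the layer helper defined above. The input shape and layer sizes are illustrative, not part of the commit.

    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=[None, 784])
    # Weights default to xavier_initializer(784, 256): uniform in [-r, r] with
    # r = sqrt(6 / (784 + 256)), roughly 0.076.
    hidden = tf.learn.fully_connected(x, 256, activation_fn=tf.nn.relu)
    logits = tf.learn.fully_connected(hidden, 10)  # linear output layer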
@@ -559,6 +559,7 @@ class OpDefLibrary(object):
                              "less than minimum %d." %
                              (key, op_type_name, len(value),
                               attr_def.minimum))
+        attr_value.list.SetInParent()
         if attr_def.type == "string":
           attr_value.s = _MakeStr(value, key)
           if attr_def.HasField("allowed_values"):
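Note: the release note above warns that golden-GraphDef tests can break. A hedged sketch of such a comparison follows; the golden file path and its contents are hypothetical.

    from google.protobuf import text_format
    import tensorflow as tf

    with tf.Graph().as_default() as g:
      tf.constant([1.0, 2.0], name="c")

    current = text_format.MessageToString(g.as_graph_def())
    golden = open("golden_graph.pbtxt").read()  # previously recorded snapshot
    # After this change, ops whose list-valued attrs are empty serialize with
    # an explicit `list {}` entry, so exact comparisons like this one may fail
    # until the golden snapshot is regenerated.
    assert current == golden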