TensorFlow: merge changes from internal
Change 110055925: Clean up the interface for adjust_contrast and adjust_brightness.
- Simplify the kernel for adjust_contrast and remove all min/max handling and casts.
- Change the semantics of the delta argument to adjust_brightness (always in [0,1)), and adjust its users.
- Add saturate_cast for casting images without over-/underflow problems.
- Add new numbers for the adjust_contrast benchmark.

This CL makes two changes to the public API:
- It changes the semantics of the delta parameter of adjust_brightness, which was previously in the same range as the input image and is now always in [0,1).
- It changes the semantics of adjust_contrast (the C++ op), which wasn't hidden but was shadowed by the Python wrapper in image_ops. It is questionable whether this function was ever part of the public API; it definitely shouldn't have been. It is now hidden, although it could become part of the public API under a different name.

Change 110054427: Update ci_build.
* Add PYTHON_BIN_PATH and always run ./configure in ci_build.
* Rename the ci_build cache directory to bazel-ci_build-cache.
* Sync ci_build/Dockerfile.cpu with docker/Dockerfile.devel.
* Use "FROM nvidia/cuda:..." for the GPU container, so the tensorflow_extra_deps directory is no longer needed.
* Share install code between containers using ./install/*.sh scripts.
* Do not inherit from (and override the FROM clause of) other Dockerfiles anymore.
* Print bazel test errors to stderr.

Change 110047126: Update ops.pbtxt.

Change 110046428: Simplify the example for the Fill op.

Base CL: 110056265
Parent: 0cf264b756
Commit: 10e62dc1e0
RELEASE.md (13 lines changed)
@ -1,3 +1,16 @@
# Changes since last release

## Breaking changes to the API

* `AdjustContrast` kernel deprecated, new kernel `AdjustContrastv2` takes and
  outputs float only. `adjust_contrast` now takes all data types.
* `adjust_brightness`'s `delta` argument is now always assumed to be in `[0,1)`
  (as is the norm for images in floating-point formats), independent of the
  data type of the input image.
* The image processing ops no longer take `min` and `max` inputs; casting
  safety is handled by `saturate_cast`, which makes sure over- and underflows
  are handled before casting to data types with smaller ranges.

# Release 0.6.0

## Major Features and Improvements
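To make the new contract concrete, here is a small NumPy sketch of the brightness change and the saturating cast back (illustrative only; the 255-based scaling and the sample values are assumptions, not code from this change):

```python
import numpy as np

# A hypothetical uint8 image with values in [0, 255].
img = np.array([[[10, 128, 250]]], dtype=np.uint8)

# New semantics: delta is given in the float image range [0, 1),
# regardless of the input dtype. The old API took delta in the input range.
delta = 10.0 / 255.0

flt = img.astype(np.float32) / 255.0   # convert to floating-point representation
adjusted = flt + delta                 # brightness adjustment in float space
# saturate_cast-style conversion back: clamp before casting so values
# such as 250 + 10 cannot overflow the uint8 range.
out = np.clip(np.round(adjusted * 255.0), 0, 255).astype(np.uint8)
print(out)  # -> [[[ 20 138 255]]]
```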
@ -20,6 +20,7 @@ message GraphDef {
  //
  // 0. Graphs created before GraphDef versioning
  // 1. First real version (2dec2015)
  // 2. adjust_contrast only takes float, doesn't perform clamping (11dec2015)
  //
  // The GraphDef version is distinct from the TensorFlow version.
  // Each released version of TensorFlow will support a range of
@ -32,11 +32,14 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// AdjustContrastOp is deprecated as of GraphDef version >= 2

template <typename Device, typename T>
class AdjustContrastOp : public OpKernel {
 public:
  explicit AdjustContrastOp(OpKernelConstruction* context)
      : OpKernel(context) {}
  explicit AdjustContrastOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_DEPRECATED(context, 2);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
@ -133,4 +136,66 @@ REGISTER_GPU_KERNEL(double);

#endif // GOOGLE_CUDA

template <typename Device>
class AdjustContrastOpv2 : public OpKernel {
 public:
  explicit AdjustContrastOpv2(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& factor = context->input(1);
    OP_REQUIRES(context, input.dims() >= 3,
                errors::InvalidArgument("input must be at least 3-D, got shape",
                                        input.shape().ShortDebugString()));
    const int64 height = input.dim_size(input.dims() - 3);
    const int64 width = input.dim_size(input.dims() - 2);
    const int64 channels = input.dim_size(input.dims() - 1);

    OP_REQUIRES(context, TensorShapeUtils::IsScalar(factor.shape()),
                errors::InvalidArgument("contrast_factor must be scalar: ",
                                        factor.shape().ShortDebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

    Tensor mean_values;
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<float>::value,
                                                   TensorShape(input.shape()),
                                                   &mean_values));

    if (input.NumElements() > 0) {
      const int64 batch = input.NumElements() / (height * width * channels);
      const int64 shape[4] = {batch, height, width, channels};
      functor::AdjustContrastv2<Device>()(
          context->eigen_device<Device>(), input.shaped<float, 4>(shape),
          factor.scalar<float>(), mean_values.shaped<float, 4>(shape),
          output->shaped<float, 4>(shape));
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("AdjustContrastv2").Device(DEVICE_CPU),
                        AdjustContrastOpv2<CPUDevice>);

#if GOOGLE_CUDA
// Forward declarations of the function specializations for GPU (to prevent
// building the GPU versions here, they will be built compiling _gpu.cu.cc).
namespace functor {
template <>
void AdjustContrastv2<GPUDevice>::operator()(
    const GPUDevice& d, typename TTypes<float, 4>::ConstTensor input,
    typename TTypes<float>::ConstScalar contrast_factor,
    typename TTypes<float, 4>::Tensor mean_values,
    typename TTypes<float, 4>::Tensor output);
extern template struct AdjustContrastv2<GPUDevice>;
#undef DECLARE_GPU_SPEC
} // namespace functor

REGISTER_KERNEL_BUILDER(Name("AdjustContrastv2").Device(DEVICE_GPU),
                        AdjustContrastOpv2<GPUDevice>);

#endif // GOOGLE_CUDA

} // namespace tensorflow
@ -73,6 +73,51 @@ struct AdjustContrast {
  }
};

// Functor used by AdjustContrastOpv2 to do the computations.
template <typename Device>
struct AdjustContrastv2 {
  void operator()(const Device& d, typename TTypes<float, 4>::ConstTensor input,
                  typename TTypes<float>::ConstScalar contrast_factor,
                  typename TTypes<float, 4>::Tensor mean_values,
                  typename TTypes<float, 4>::Tensor output) {
    const int batch = input.dimension(0);
    const int height = input.dimension(1);
    const int width = input.dimension(2);
    const int channels = input.dimension(3);

    Eigen::array<int, 4> scalar_broadcast{{batch, height, width, channels}};
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::array<int, 2> reduction_axis{{1, 2}};
    Eigen::array<int, 4> scalar{{1, 1, 1, 1}};
    Eigen::array<int, 4> broadcast_dims{{1, height, width, 1}};
    Eigen::Tensor<int, 4>::Dimensions reshape_dims{{batch, 1, 1, channels}};
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2> >
        reduction_axis;
    Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<1>,
                     Eigen::type2index<1>, Eigen::type2index<1> >
        scalar;
    Eigen::IndexList<Eigen::type2index<1>, int, int, Eigen::type2index<1> >
        broadcast_dims;
    broadcast_dims.set(1, height);
    broadcast_dims.set(2, width);
    Eigen::IndexList<int, Eigen::type2index<1>, Eigen::type2index<1>, int>
        reshape_dims;
    reshape_dims.set(0, batch);
    reshape_dims.set(3, channels);
#endif
    mean_values.device(d) = input.mean(reduction_axis)
                                .eval()
                                .reshape(reshape_dims)
                                .broadcast(broadcast_dims);
    auto contrast_factor_tensor =
        contrast_factor.reshape(scalar).broadcast(scalar_broadcast);
    auto adjusted =
        (input - mean_values) * contrast_factor_tensor + mean_values;
    output.device(d) = adjusted;
  }
};

} // namespace functor
} // namespace tensorflow
@ -27,17 +27,11 @@ static Graph* BM_AdjustContrast(int batches, int width, int height) {
|
||||
in.flat<uint8>().setRandom();
|
||||
Tensor factor(DT_FLOAT, TensorShape({}));
|
||||
factor.flat<float>().setConstant(1.2);
|
||||
Tensor min_value(DT_FLOAT, TensorShape({}));
|
||||
min_value.flat<float>().setConstant(7.);
|
||||
Tensor max_value(DT_FLOAT, TensorShape({}));
|
||||
max_value.flat<float>().setConstant(250.);
|
||||
|
||||
Node* ret;
|
||||
NodeBuilder(g->NewName("n"), "AdjustContrast")
|
||||
NodeBuilder(g->NewName("n"), "AdjustContrastv2")
|
||||
.Input(test::graph::Constant(g, in))
|
||||
.Input(test::graph::Constant(g, factor))
|
||||
.Input(test::graph::Constant(g, min_value))
|
||||
.Input(test::graph::Constant(g, max_value))
|
||||
.Finalize(g, &ret);
|
||||
return g;
|
||||
}
|
||||
@ -47,12 +41,21 @@ static Graph* BM_AdjustContrast(int batches, int width, int height) {
|
||||
testing::ItemsProcessed(iters* B* W* H * 3); \
|
||||
test::Benchmark(#DEVICE, BM_AdjustContrast(B, W, H)).Run(iters); \
|
||||
} \
|
||||
BENCHMARK(BM_AdjustContrast_##DEVICE##_##B##_##W##_##H);
|
||||
BENCHMARK(BM_AdjustContrast_##DEVICE##_##B##_##W##_##H)
|
||||
|
||||
// Benchmark results as of cl/106323955
|
||||
// BM_AdjustContrast_cpu_1_299_299 3416770 22008951 100 11.6M items/s
|
||||
|
||||
// BM_AdjustContrast_gpu_32_299_299 37117844 45512374 100 179.8M items/s
|
||||
BM_AdjustContrastDev(cpu, 1, 299, 299) BM_AdjustContrastDev(gpu, 32, 299, 299)
|
||||
// Benchmark results as of cl/109478777
|
||||
// (note that the definition has changed to perform no min/max or clamping,
|
||||
// so a comparison to cl/106323955 is inherently unfair)
|
||||
// The GPU test ran with -c opt --config=gcudacc --copt=-mavx, CPU ran without
|
||||
// --config=gcudacc because for some reason that killed throughput measurement.
|
||||
// CPU: Intel Haswell with HyperThreading (6 cores) dL1:32KB dL2:256KB dL3:15MB
|
||||
// GPU: Tesla K40m
|
||||
// BM_AdjustContrast_cpu_1_299_299 179084 340186 2181 751.9M items/s
|
||||
// BM_AdjustContrast_gpu_32_299_299 85276 123665 4189 2.9G items/s
|
||||
BM_AdjustContrastDev(cpu, 1, 299, 299);
|
||||
BM_AdjustContrastDev(gpu, 32, 299, 299);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -24,6 +24,11 @@ limitations under the License.
namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

// this is for v2
template struct functor::AdjustContrastv2<GPUDevice>;

// these are for v1
template struct functor::AdjustContrast<GPUDevice, uint8>;
template struct functor::AdjustContrast<GPUDevice, int8>;
template struct functor::AdjustContrast<GPUDevice, int16>;
@ -36,55 +36,42 @@ class AdjustContrastOpTest : public OpsTestBase {
|
||||
|
||||
TEST_F(AdjustContrastOpTest, Simple_1113) {
|
||||
RequireDefaultOps();
|
||||
EXPECT_OK(NodeDefBuilder("adjust_constrast_op", "AdjustContrast")
|
||||
EXPECT_OK(NodeDefBuilder("adjust_constrast_op", "AdjustContrastv2")
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Attr("T", DT_FLOAT)
|
||||
.Finalize(node_def()));
|
||||
EXPECT_OK(InitOp());
|
||||
AddInputFromArray<float>(TensorShape({1, 1, 1, 3}), {-1, 2, 3});
|
||||
AddInputFromArray<float>(TensorShape({}), {1.0});
|
||||
AddInputFromArray<float>(TensorShape({}), {0.0});
|
||||
AddInputFromArray<float>(TensorShape({}), {2.0});
|
||||
ASSERT_OK(RunOpKernel());
|
||||
|
||||
Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1, 3}));
|
||||
test::FillValues<float>(&expected, {0, 2, 2});
|
||||
test::FillValues<float>(&expected, {-1, 2, 3});
|
||||
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
TEST_F(AdjustContrastOpTest, Simple_1223) {
|
||||
RequireDefaultOps();
|
||||
EXPECT_OK(NodeDefBuilder("adjust_constrast_op", "AdjustContrast")
|
||||
EXPECT_OK(NodeDefBuilder("adjust_constrast_op", "AdjustContrastv2")
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Attr("T", DT_FLOAT)
|
||||
.Finalize(node_def()));
|
||||
EXPECT_OK(InitOp());
|
||||
AddInputFromArray<float>(TensorShape({1, 2, 2, 3}),
|
||||
{1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12});
|
||||
AddInputFromArray<float>(TensorShape({}), {0.2});
|
||||
AddInputFromArray<float>(TensorShape({}), {0.0});
|
||||
AddInputFromArray<float>(TensorShape({}), {10.0});
|
||||
ASSERT_OK(RunOpKernel());
|
||||
|
||||
Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 2, 3}));
|
||||
test::FillValues<float>(
|
||||
&expected, {2.2, 6.2, 10, 2.4, 6.4, 10, 2.6, 6.6, 10, 2.8, 6.8, 10});
|
||||
test::FillValues<float>(&expected, {2.2, 6.2, 10.2, 2.4, 6.4, 10.4, 2.6, 6.6,
|
||||
10.6, 2.8, 6.8, 10.8});
|
||||
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
TEST_F(AdjustContrastOpTest, Big_99x99x3) {
|
||||
EXPECT_OK(NodeDefBuilder("adjust_constrast_op", "AdjustContrast")
|
||||
EXPECT_OK(NodeDefBuilder("adjust_constrast_op", "AdjustContrastv2")
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Attr("T", DT_FLOAT)
|
||||
.Finalize(node_def()));
|
||||
EXPECT_OK(InitOp());
|
||||
|
||||
@ -95,8 +82,6 @@ TEST_F(AdjustContrastOpTest, Big_99x99x3) {
|
||||
|
||||
AddInputFromArray<float>(TensorShape({1, 99, 99, 3}), values);
|
||||
AddInputFromArray<float>(TensorShape({}), {0.2});
|
||||
AddInputFromArray<float>(TensorShape({}), {0});
|
||||
AddInputFromArray<float>(TensorShape({}), {255});
|
||||
ASSERT_OK(RunOpKernel());
|
||||
}
|
||||
|
||||
|
@ -1,3 +1,18 @@
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
@ -102,6 +102,22 @@ using ::tensorflow::error::OK;
// ...
// }

// Declares an op deprecated, and illegal starting at GraphDef version VERSION
#define OP_DEPRECATED(CTX, VERSION)                                            \
  if ((CTX)->graph_def_version() >= (VERSION)) {                              \
    ::tensorflow::Status _s(::tensorflow::errors::Unimplemented(              \
        "Op ", (CTX)->op_def().name(),                                        \
        " is not available in GraphDef version ", (CTX)->graph_def_version(), \
        ". It has been removed in version ", (VERSION), "."));                \
    VLOG(1) << _s;                                                            \
    (CTX)->SetStatus(_s);                                                     \
    return;                                                                   \
  } else {                                                                    \
    LOG(WARNING) << "Op is deprecated."                                       \
                 << " It will cease to work in GraphDef version " << (VERSION) \
                 << ".";                                                      \
  }

#define OP_REQUIRES(CTX, EXP, STATUS) \
  if (!(EXP)) {                       \
    ::tensorflow::Status _s(STATUS);  \
@ -290,10 +290,9 @@ This operation creates a tensor of shape `dims` and fills it with `value`.
For example:

```prettyprint
# output tensor shape needs to be [2, 3]
# so 'dims' is [2, 3]
fill(dims, 9) ==> [[9, 9, 9]
                   [9, 9, 9]]
# Output tensor has shape [2, 3].
fill([2, 3], 9) ==> [[9, 9, 9]
                     [9, 9, 9]]
```

dims: 1-D. Represents the shape of the output tensor.
@ -244,6 +244,15 @@ REGISTER_OP("AdjustContrast")
    .Output("output: float")
    .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
    .Doc(R"Doc(
Deprecated. Disallowed in GraphDef version >= 2.
)Doc");

// --------------------------------------------------------------------------
REGISTER_OP("AdjustContrastv2")
    .Input("images: float")
    .Input("contrast_factor: float")
    .Output("output: float")
    .Doc(R"Doc(
Adjust the contrast of one or more images.

`images` is a tensor of at least 3 dimensions. The last 3 dimensions are
@ -256,13 +265,8 @@ For each channel, the Op first computes the mean of the image pixels in the
channel and then adjusts each component of each pixel to
`(x - mean) * contrast_factor + mean`.

These adjusted values are then clipped to fit in the `[min_value, max_value]`
interval.

`images: Images to adjust. At least 3-D.
images: Images to adjust. At least 3-D.
contrast_factor: A float multiplier for adjusting contrast.
min_value: Minimum value for clipping the adjusted pixels.
max_value: Maximum value for clipping the adjusted pixels.
output: The constrast-adjusted image or images.
)Doc");
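The per-channel computation documented above, `(x - mean) * contrast_factor + mean` with the mean taken over height and width, can be sketched in NumPy as follows (an illustrative sketch, not the registered kernel; the random input shape is an assumption):

```python
import numpy as np

def adjust_contrast_v2_sketch(images, contrast_factor):
    """NumPy sketch of AdjustContrastv2: images has shape
    [batch, height, width, channels] and is already float."""
    # Per-channel mean over the spatial dimensions, kept broadcastable.
    mean = images.mean(axis=(1, 2), keepdims=True)
    # Move each pixel away from (factor > 1) or towards (factor < 1) the mean.
    return (images - mean) * contrast_factor + mean

images = np.random.rand(2, 4, 4, 3).astype(np.float32)
doubled_contrast = adjust_contrast_v2_sketch(images, 2.0)
```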
|
||||
|
@ -106,22 +106,18 @@ op {
|
||||
}
|
||||
input_arg {
|
||||
name: "contrast_factor"
|
||||
description: "A float multiplier for adjusting contrast."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "min_value"
|
||||
description: "Minimum value for clipping the adjusted pixels."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "max_value"
|
||||
description: "Maximum value for clipping the adjusted pixels."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "The constrast-adjusted image or images."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
attr {
|
||||
@ -139,8 +135,27 @@ op {
|
||||
}
|
||||
}
|
||||
}
|
||||
summary: "Deprecated. Disallowed in GraphDef version >= 2."
|
||||
}
|
||||
op {
|
||||
name: "AdjustContrastv2"
|
||||
input_arg {
|
||||
name: "images"
|
||||
description: "Images to adjust. At least 3-D."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
input_arg {
|
||||
name: "contrast_factor"
|
||||
description: "A float multiplier for adjusting contrast."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "The constrast-adjusted image or images."
|
||||
type: DT_FLOAT
|
||||
}
|
||||
summary: "Adjust the contrast of one or more images."
|
||||
description: "`images` is a tensor of at least 3 dimensions. The last 3 dimensions are\ninterpreted as `[height, width, channels]`. The other dimensions only\nrepresent a collection of images, such as `[batch, height, width, channels].`\n\nContrast is adjusted independently for each channel of each image.\n\nFor each channel, the Op first computes the mean of the image pixels in the\nchannel and then adjusts each component of each pixel to\n`(x - mean) * contrast_factor + mean`.\n\nThese adjusted values are then clipped to fit in the `[min_value, max_value]`\ninterval.\n\n`images: Images to adjust. At least 3-D."
|
||||
description: "`images` is a tensor of at least 3 dimensions. The last 3 dimensions are\ninterpreted as `[height, width, channels]`. The other dimensions only\nrepresent a collection of images, such as `[batch, height, width, channels].`\n\nContrast is adjusted independently for each channel of each image.\n\nFor each channel, the Op first computes the mean of the image pixels in the\nchannel and then adjusts each component of each pixel to\n`(x - mean) * contrast_factor + mean`."
|
||||
}
|
||||
op {
|
||||
name: "All"
|
||||
@ -2594,7 +2609,7 @@ op {
|
||||
type: "type"
|
||||
}
|
||||
summary: "Creates a tensor filled with a scalar value."
|
||||
description: "This operation creates a tensor of shape `dims` and fills it with `value`.\n\nFor example:\n\n```prettyprint\n# output tensor shape needs to be [2, 3]\n# so \'dims\' is [2, 3]\nfill(dims, 9) ==> [[9, 9, 9]\n [9, 9, 9]]\n```"
|
||||
description: "This operation creates a tensor of shape `dims` and fills it with `value`.\n\nFor example:\n\n```prettyprint\n# Output tensor has shape [2, 3].\nfill([2, 3], 9) ==> [[9, 9, 9]\n [9, 9, 9]]\n```"
|
||||
}
|
||||
op {
|
||||
name: "FixedLengthRecordReader"
|
||||
|
@ -140,10 +140,9 @@ This operation creates a tensor of shape `dims` and fills it with `value`.
For example:

```prettyprint
# output tensor shape needs to be [2, 3]
# so 'dims' is [2, 3]
fill(dims, 9) ==> [[9, 9, 9]
                   [9, 9, 9]]
# Output tensor has shape [2, 3].
fill([2, 3], 9) ==> [[9, 9, 9]
                     [9, 9, 9]]
```

##### Args:
@ -419,6 +419,7 @@ tf_gen_op_wrapper_py(
    hidden = [
        "ResizeBilinearGrad",
        "ResizeNearestNeighborGrad",
        "AdjustContrastv2",
        "ScaleImageGrad",
    ],
    require_shape_functions = True,
@ -627,25 +627,19 @@ def random_brightness(image, max_delta, seed=None):
  Equivalent to `adjust_brightness()` using a `delta` randomly picked in the
  interval `[-max_delta, max_delta)`.

  Note that `delta` is picked as a float. Because for integer type images,
  the brightness adjusted result is rounded before casting, integer images may
  have modifications in the range `[-max_delta,max_delta]`.

  Args:
    image: 3-D tensor of shape `[height, width, channels]`.
    image: An image.
    max_delta: float, must be non-negative.
    seed: A Python integer. Used to create a random seed. See
      [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
      for behavior.

  Returns:
    3-D tensor of images of shape `[height, width, channels]`
    The brightness-adjusted image.

  Raises:
    ValueError: if `max_delta` is negative.
  """
  _Check3DImage(image)

  if max_delta < 0:
    raise ValueError('max_delta must be non-negative.')

@ -660,7 +654,7 @@ def random_contrast(image, lower, upper, seed=None):
  picked in the interval `[lower, upper]`.

  Args:
    image: 3-D tensor of shape `[height, width, channels]`.
    image: An image tensor with 3 or more dimensions.
    lower: float. Lower bound for the random contrast factor.
    upper: float. Upper bound for the random contrast factor.
    seed: A Python integer. Used to create a random seed. See
@ -668,13 +662,11 @@ def random_contrast(image, lower, upper, seed=None):
      for behavior.

  Returns:
    3-D tensor of shape `[height, width, channels]`.
    The contrast-adjusted tensor.

  Raises:
    ValueError: if `upper <= lower` or if `lower < 0`.
  """
  _Check3DImage(image)

  if upper <= lower:
    raise ValueError('upper must be > lower.')
@ -686,104 +678,82 @@ def random_contrast(image, lower, upper, seed=None):
  return adjust_contrast(image, contrast_factor)


def adjust_brightness(image, delta, min_value=None, max_value=None):
def adjust_brightness(image, delta):
  """Adjust the brightness of RGB or Grayscale images.

  The value `delta` is added to all components of the tensor `image`. `image`
  and `delta` are cast to `float` before adding, and the resulting values are
  clamped to `[min_value, max_value]`. Finally, the result is cast back to
  `images.dtype`.
  This is a convenience method that converts an RGB image to float
  representation, adjusts its brightness, and then converts it back to the
  original data type. If several adjustments are chained it is advisable to
  minimize the number of redundant conversions.

  If `min_value` or `max_value` are not given, they are set to the minimum and
  maximum allowed values for `image.dtype` respectively.
  The value `delta` is added to all components of the tensor `image`. Both
  `image` and `delta` are converted to `float` before adding (and `image` is
  scaled appropriately if it is in fixed-point representation). For regular
  images, `delta` should be in the range `[0,1)`, as it is added to the image in
  floating point representation, where pixel values are in the `[0,1)` range.

  Args:
    image: A tensor.
    delta: A scalar. Amount to add to the pixel values.
    min_value: Minimum value for output.
    max_value: Maximum value for output.

  Returns:
    A tensor of the same shape and type as `image`.
    A brightness-adjusted tensor of the same shape and type as `image`.
  """
  if min_value is None:
    min_value = image.dtype.min
  if max_value is None:
    max_value = image.dtype.max
  with ops.op_scope([image, delta], None, 'adjust_brightness') as name:
    # Remember original dtype to so we can convert back if needed
    orig_dtype = image.dtype
    flt_image = convert_image_dtype(image, dtypes.float32)

  with ops.op_scope([image, delta, min_value, max_value], None,
                    'adjust_brightness') as name:
    adjusted = math_ops.add(
        math_ops.cast(image, dtypes.float32),
        math_ops.cast(delta, dtypes.float32),
        name=name)
    if image.dtype.is_integer:
      rounded = math_ops.round(adjusted)
    else:
      rounded = adjusted
    clipped = clip_ops.clip_by_value(rounded, float(min_value),
                                     float(max_value))
    output = math_ops.cast(clipped, image.dtype)
    return output
    adjusted = math_ops.add(flt_image,
                            math_ops.cast(delta, dtypes.float32),
                            name=name)

    return convert_image_dtype(adjusted, orig_dtype, saturate=True)


def adjust_contrast(images, contrast_factor, min_value=None, max_value=None):
def adjust_contrast(images, contrast_factor):
  """Adjust contrast of RGB or grayscale images.

  This is a convenience method that converts an RGB image to float
  representation, adjusts its contrast, and then converts it back to the
  original data type. If several adjustments are chained it is advisable to
  minimize the number of redundant conversions.

  `images` is a tensor of at least 3 dimensions. The last 3 dimensions are
  interpreted as `[height, width, channels]`. The other dimensions only
  represent a collection of images, such as `[batch, height, width, channels].`

  Contrast is adjusted independently for each channel of each image.

  For each channel, this Op first computes the mean of the image pixels in the
  For each channel, this Op computes the mean of the image pixels in the
  channel and then adjusts each component `x` of each pixel to
  `(x - mean) * contrast_factor + mean`.

  The adjusted values are then clipped to fit in the `[min_value, max_value]`
  interval. If `min_value` or `max_value` is not given, it is replaced with the
  minimum and maximum values for the data type of `images` respectively.

  The contrast-adjusted image is always computed as `float`, and it is
  cast back to its original type after clipping.

  Args:
    images: Images to adjust. At least 3-D.
    contrast_factor: A float multiplier for adjusting contrast.
    min_value: Minimum value for clipping the adjusted pixels.
    max_value: Maximum value for clipping the adjusted pixels.

  Returns:
    The constrast-adjusted image or images.

  Raises:
    ValueError: if the arguments are invalid.
  """
  _CheckAtLeast3DImage(images)
  with ops.op_scope([images, contrast_factor], None, 'adjust_contrast') as name:
    # Remember original dtype to so we can convert back if needed
    orig_dtype = images.dtype
    flt_images = convert_image_dtype(images, dtypes.float32)

    # If these are None, the min/max should be a nop, but still prevent overflows
    # from the cast back to images.dtype at the end of adjust_contrast.
    if min_value is None:
      min_value = images.dtype.min
    if max_value is None:
      max_value = images.dtype.max
    # pylint: disable=protected-access
    adjusted = gen_image_ops._adjust_contrastv2(flt_images,
                                                contrast_factor=contrast_factor,
                                                name=name)
    # pylint: enable=protected-access

  with ops.op_scope(
      [images, contrast_factor, min_value,
       max_value], None, 'adjust_contrast') as name:
    adjusted = gen_image_ops.adjust_contrast(images,
                                             contrast_factor=contrast_factor,
                                             min_value=min_value,
                                             max_value=max_value,
                                             name=name)
    if images.dtype.is_integer:
      return math_ops.cast(math_ops.round(adjusted), images.dtype)
    else:
      return math_ops.cast(adjusted, images.dtype)
    return convert_image_dtype(adjusted, orig_dtype, saturate=True)


ops.RegisterShape('AdjustContrast')(
    common_shapes.unchanged_shape_with_rank_at_least(3))
ops.RegisterShape('AdjustContrastv2')(
    common_shapes.unchanged_shape_with_rank_at_least(3))


@ops.RegisterShape('ResizeBilinear')
|
||||
name=name)
|
||||
|
||||
|
||||
def convert_image_dtype(image, dtype, name=None):
|
||||
def saturate_cast(image, dtype):
|
||||
"""Performs a safe cast of image data to `dtype`.
|
||||
|
||||
This function casts the data in image to `dtype`, without applying any
|
||||
scaling. If there is a danger that image data would over or underflow in the
|
||||
cast, this op applies the appropriate clamping before the cast.
|
||||
|
||||
Args:
|
||||
image: An image to cast to a different data type.
|
||||
dtype: A `DType` to cast `image` to.
|
||||
|
||||
Returns:
|
||||
`image`, safely cast to `dtype`.
|
||||
"""
|
||||
clamped = image
|
||||
|
||||
# When casting to a type with smaller representable range, clamp.
|
||||
# Note that this covers casting to unsigned types as well.
|
||||
if image.dtype.min < dtype.min and image.dtype.max > dtype.max:
|
||||
clamped = clip_ops.clip_by_value(clamped,
|
||||
math_ops.cast(dtype.min, image.dtype),
|
||||
math_ops.cast(dtype.max, image.dtype))
|
||||
elif image.dtype.min < dtype.min:
|
||||
clamped = math_ops.maximum(clamped, math_ops.cast(dtype.min, image.dtype))
|
||||
elif image.dtype.max > dtype.max:
|
||||
clamped = math_ops.minimum(clamped, math_ops.cast(dtype.max, image.dtype))
|
||||
|
||||
return math_ops.cast(clamped, dtype)
|
||||
|
||||
|
||||
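The clamp-then-cast behaviour of `saturate_cast` above can be mirrored in NumPy (an illustrative analogue with made-up values, not the TensorFlow implementation):

```python
import numpy as np

def saturate_cast_sketch(values, dst_dtype):
    """NumPy analogue of saturate_cast: clip into the destination integer
    range before casting, so the cast itself cannot over- or underflow."""
    info = np.iinfo(dst_dtype)
    return np.clip(values, info.min, info.max).astype(dst_dtype)

x = np.array([-12.0, 3.0, 300.0], dtype=np.float32)
print(saturate_cast_sketch(x, np.uint8))  # -> [  0   3 255]
print(x.astype(np.uint8))                 # unsafe cast; out-of-range values wrap or are undefined
```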
def convert_image_dtype(image, dtype, saturate=False, name=None):
  """Convert `image` to `dtype`, scaling its values if needed.

  Images that are represented using floating point values are expected to have
@ -872,13 +872,17 @@ def convert_image_dtype(image, dtype, name=None):
  This op converts between data types, scaling the values appropriately before
  casting.

  Note that for floating point inputs, this op expects values to lie in [0,1).
  Conversion of an image containing values outside that range may lead to
  overflow errors when converted to integer `Dtype`s.
  Note that converting from floating point inputs to integer types may lead to
  over/underflow problems. Set saturate to `True` to avoid such problem in
  problematic conversions. Saturation will clip the output into the allowed
  range before performing a potentially dangerous cast (i.e. when casting from
  a floating point to an integer type, or when casting from an signed to an
  unsigned type).

  Args:
    image: An image.
    dtype: A `DType` to convert `image` to.
    saturate: If `True`, clip the input before casting (if necessary).
    name: A name for this operation (optional).

  Returns:
@ -899,19 +903,28 @@ def convert_image_dtype(image, dtype, name=None):
      # so that the output is safely in the supported range.
      scale = (scale_in + 1) // (scale_out + 1)
      scaled = math_ops.div(image, scale)
      return math_ops.cast(scaled, dtype)

      if saturate:
        return saturate_cast(scaled, dtype)
      else:
        return math_ops.cast(scaled, dtype)
    else:
      # Scaling up, cast first, then scale. The scale will not map in.max to
      # out.max, but converting back and forth should result in no change.
      cast = math_ops.cast(image, dtype)
      if saturate:
        cast = saturate_cast(scaled, dtype)
      else:
        cast = math_ops.cast(image, dtype)
      scale = (scale_out + 1) // (scale_in + 1)
      return math_ops.mul(cast, scale)
  elif image.dtype.is_floating and dtype.is_floating:
    # Both float: Just cast, no possible overflows in the allowed ranges.
    # Note: We're ignoreing float overflows. If your image dynamic range
    # exceeds float range you're on your own.
    return math_ops.cast(image, dtype)
  else:
    if image.dtype.is_integer:
      # Converting to float: first cast, then scale
      # Converting to float: first cast, then scale. No saturation possible.
      cast = math_ops.cast(image, dtype)
      scale = 1. / image.dtype.max
      return math_ops.mul(cast, scale)
@ -919,7 +932,10 @@ def convert_image_dtype(image, dtype, name=None):
      # Converting from float: first scale, then cast
      scale = dtype.max + 0.5  # avoid rounding problems in the cast
      scaled = math_ops.mul(image, scale)
      return math_ops.cast(scaled, dtype)
      if saturate:
        return saturate_cast(scaled, dtype)
      else:
        return math_ops.cast(scaled, dtype)


def rgb_to_grayscale(images):
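For the integer-to-integer branch of `convert_image_dtype` above, the scale factor comes from the ratio of the two ranges; here is a NumPy sketch of a uint8 to uint16 round trip under those rules (illustrative assumptions, not the TensorFlow code):

```python
import numpy as np

# Scaling up (uint8 -> uint16): cast first, then multiply by
# (out_max + 1) // (in_max + 1) = 65536 // 256 = 256.
u8 = np.array([0, 1, 128, 255], dtype=np.uint8)
u16 = u8.astype(np.uint16) * ((65535 + 1) // (255 + 1))

# Scaling down (uint16 -> uint8): divide first, then cast, so the result
# already sits inside the smaller range before the cast happens.
back = (u16 // ((65535 + 1) // (255 + 1))).astype(np.uint8)
assert np.array_equal(back, u8)  # converting back and forth loses nothing
```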
@ -302,85 +302,64 @@ class RandomFlipTest(test_util.TensorFlowTestCase):
|
||||
|
||||
class AdjustContrastTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def _testContrast(self, x_np, y_np, contrast_factor, min_value, max_value):
|
||||
def _testContrast(self, x_np, y_np, contrast_factor):
|
||||
for use_gpu in [True, False]:
|
||||
with self.test_session(use_gpu=use_gpu):
|
||||
x = constant_op.constant(x_np, shape=x_np.shape)
|
||||
y = image_ops.adjust_contrast(x,
|
||||
contrast_factor,
|
||||
min_value=min_value,
|
||||
max_value=max_value)
|
||||
y = image_ops.adjust_contrast(x, contrast_factor)
|
||||
y_tf = y.eval()
|
||||
self.assertAllEqual(y_tf, y_np)
|
||||
self.assertAllClose(y_tf, y_np, 1e-6)
|
||||
|
||||
def testDoubleContrastUint8(self):
|
||||
x_shape = [1, 2, 2, 3]
|
||||
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
|
||||
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
|
||||
|
||||
y_data = [0, 0, 0, 63, 169, 255, 29, 0, 255, 135, 255, 0]
|
||||
y_data = [0, 0, 0, 62, 169, 255, 28, 0, 255, 135, 255, 0]
|
||||
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
|
||||
|
||||
self._testContrast(x_np,
|
||||
y_np,
|
||||
contrast_factor=2.0,
|
||||
min_value=None,
|
||||
max_value=None)
|
||||
self._testContrast(x_np, y_np, contrast_factor=2.0)
|
||||
|
||||
def testDoubleContrastFloat(self):
|
||||
x_shape = [1, 2, 2, 3]
|
||||
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
|
||||
x_np = np.array(x_data, dtype=np.float).reshape(x_shape)
|
||||
x_np = np.array(x_data, dtype=np.float).reshape(x_shape) / 255.
|
||||
|
||||
y_data = [0, 0, 0, 62.75, 169.25, 255, 28.75, 0, 255, 134.75, 255, 0]
|
||||
y_np = np.array(y_data, dtype=np.float).reshape(x_shape)
|
||||
y_data = [-45.25, -90.75, -92.5, 62.75, 169.25, 333.5, 28.75, -84.75, 349.5,
|
||||
134.75, 409.25, -116.5]
|
||||
y_np = np.array(y_data, dtype=np.float).reshape(x_shape) / 255.
|
||||
|
||||
self._testContrast(x_np,
|
||||
y_np,
|
||||
contrast_factor=2.0,
|
||||
min_value=0,
|
||||
max_value=255)
|
||||
self._testContrast(x_np, y_np, contrast_factor=2.0)
|
||||
|
||||
def testHalfContrastUint8(self):
|
||||
x_shape = [1, 2, 2, 3]
|
||||
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
|
||||
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
|
||||
|
||||
y_data = [23, 53, 66, 50, 118, 172, 41, 54, 176, 68, 178, 60]
|
||||
y_data = [22, 52, 65, 49, 118, 172, 41, 54, 176, 67, 178, 59]
|
||||
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
|
||||
|
||||
self._testContrast(x_np,
|
||||
y_np,
|
||||
contrast_factor=0.5,
|
||||
min_value=None,
|
||||
max_value=None)
|
||||
self._testContrast(x_np, y_np, contrast_factor=0.5)
|
||||
|
||||
def testBatchDoubleContrast(self):
|
||||
x_shape = [2, 1, 2, 3]
|
||||
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
|
||||
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
|
||||
|
||||
y_data = [0, 0, 0, 81, 200, 255, 11, 0, 255, 117, 255, 0]
|
||||
y_data = [0, 0, 0, 81, 200, 255, 10, 0, 255, 116, 255, 0]
|
||||
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
|
||||
|
||||
self._testContrast(x_np,
|
||||
y_np,
|
||||
contrast_factor=2.0,
|
||||
min_value=None,
|
||||
max_value=None)
|
||||
self._testContrast(x_np, y_np, contrast_factor=2.0)
|
||||
|
||||
|
||||
class AdjustBrightnessTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def _testBrightness(self, x_np, y_np, delta, min_value, max_value):
|
||||
def _testBrightness(self, x_np, y_np, delta):
|
||||
with self.test_session():
|
||||
x = constant_op.constant(x_np, shape=x_np.shape)
|
||||
y = image_ops.adjust_brightness(x,
|
||||
delta,
|
||||
min_value=min_value,
|
||||
max_value=max_value)
|
||||
y = image_ops.adjust_brightness(x, delta)
|
||||
y_tf = y.eval()
|
||||
self.assertAllEqual(y_tf, y_np)
|
||||
self.assertAllClose(y_tf, y_np, 1e-6)
|
||||
|
||||
def testPositiveDeltaUint8(self):
|
||||
x_shape = [2, 2, 3]
|
||||
@ -390,27 +369,27 @@ class AdjustBrightnessTest(test_util.TensorFlowTestCase):
|
||||
y_data = [10, 15, 23, 64, 145, 236, 47, 18, 244, 100, 255, 11]
|
||||
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
|
||||
|
||||
self._testBrightness(x_np, y_np, delta=10.0, min_value=None, max_value=None)
|
||||
self._testBrightness(x_np, y_np, delta=10. / 255.)
|
||||
|
||||
def testPositiveDeltaFloat(self):
|
||||
x_shape = [2, 2, 3]
|
||||
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
|
||||
x_np = np.array(x_data, dtype=np.float32).reshape(x_shape)
|
||||
x_np = np.array(x_data, dtype=np.float32).reshape(x_shape) / 255.
|
||||
|
||||
y_data = [10, 15, 23, 64, 145, 236, 47, 18, 244, 100, 265, 11]
|
||||
y_np = np.array(y_data, dtype=np.float32).reshape(x_shape)
|
||||
y_np = np.array(y_data, dtype=np.float32).reshape(x_shape) / 255.
|
||||
|
||||
self._testBrightness(x_np, y_np, delta=10.0, min_value=None, max_value=None)
|
||||
self._testBrightness(x_np, y_np, delta=10. / 255.)
|
||||
|
||||
def testNegativeDelta(self):
|
||||
x_shape = [2, 2, 3]
|
||||
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
|
||||
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
|
||||
|
||||
y_data = [5, 5, 5, 44, 125, 216, 27, 5, 224, 80, 245, 5]
|
||||
y_data = [0, 0, 3, 44, 125, 216, 27, 0, 224, 80, 245, 0]
|
||||
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
|
||||
|
||||
self._testBrightness(x_np, y_np, delta=-10.0, min_value=5, max_value=None)
|
||||
self._testBrightness(x_np, y_np, delta=-10. / 255.)
|
||||
|
||||
|
||||
class RandomCropTest(test_util.TensorFlowTestCase):
|
||||
@ -955,7 +934,7 @@ class PngTest(test_util.TensorFlowTestCase):
|
||||
self.assertLessEqual(len(png0), 750)
|
||||
|
||||
def testShape(self):
|
||||
with self.test_session() as sess:
|
||||
with self.test_session():
|
||||
png = constant_op.constant('nonsense')
|
||||
for channels in 0, 1, 3:
|
||||
image = image_ops.decode_png(png, channels=channels)
|
||||
|
@ -1,7 +1,18 @@
|
||||
FROM tensorflow:ci_build.cpu
|
||||
FROM ubuntu:14.04
|
||||
|
||||
MAINTAINER Jan Prach <jendap@google.com>
|
||||
|
||||
# Copy ci_build install scripts into the container.
|
||||
COPY install /install
|
||||
|
||||
# Run the install scripts.
|
||||
RUN install/install_deb_packages.sh
|
||||
RUN install/install_openjdk8_from_ppa.sh
|
||||
RUN install/install_bazel.sh
|
||||
|
||||
# Set up BAZELRC environment variable.
|
||||
ENV BAZELRC /root/.bazelrc
|
||||
|
||||
# Install Android SDK.
|
||||
ENV ANDROID_SDK_FILENAME android-sdk_r24.4.1-linux.tgz
|
||||
ENV ANDROID_SDK_URL http://dl.google.com/android/${ANDROID_SDK_FILENAME}
|
||||
|
@ -2,51 +2,13 @@ FROM ubuntu:14.04
|
||||
|
||||
MAINTAINER Jan Prach <jendap@google.com>
|
||||
|
||||
# Install dependencies for bazel.
|
||||
RUN apt-get update && apt-get install -y \
|
||||
g++ \
|
||||
pkg-config \
|
||||
python-dev \
|
||||
python-numpy \
|
||||
python-pip \
|
||||
software-properties-common \
|
||||
swig \
|
||||
unzip \
|
||||
wget \
|
||||
zip \
|
||||
zlib1g-dev \
|
||||
&& \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
# Copy ci_build install scripts into the container.
|
||||
COPY install /install
|
||||
|
||||
# Install openjdk 8 for bazel from PPA (it is not available in 14.04).
|
||||
RUN add-apt-repository -y ppa:openjdk-r/ppa && \
|
||||
apt-get update && \
|
||||
apt-get install -y openjdk-8-jdk openjdk-8-jre-headless && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
# Run the install scripts.
|
||||
RUN install/install_deb_packages.sh
|
||||
RUN install/install_openjdk8_from_ppa.sh
|
||||
RUN install/install_bazel.sh
|
||||
|
||||
# Install the most recent bazel release.
|
||||
ENV BAZEL_VERSION 0.1.1
|
||||
RUN mkdir /bazel && \
|
||||
cd /bazel && \
|
||||
wget https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
|
||||
wget -O /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE.txt && \
|
||||
chmod +x bazel-*.sh && \
|
||||
./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
|
||||
rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh
|
||||
|
||||
# Enable bazel auto completion.
|
||||
RUN echo "source /usr/local/lib/bazel/bin/bazel-complete.bash" >> ~/.bashrc
|
||||
|
||||
# Running bazel inside a `docker build` command causes trouble, cf:
|
||||
# https://github.com/bazelbuild/bazel/issues/134
|
||||
# The easiest solution is to set up a bazelrc file forcing --batch.
|
||||
RUN echo "startup --batch" >>/root/.bazelrc
|
||||
# Similarly, we need to workaround sandboxing issues:
|
||||
# https://github.com/bazelbuild/bazel/issues/418
|
||||
RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
|
||||
>>/root/.bazelrc
|
||||
# Force bazel output to use colors (good for jenkins).
|
||||
RUN echo "common --color=yes" >>/root/.bazelrc
|
||||
# Set up BAZELRC environment variable.
|
||||
ENV BAZELRC /root/.bazelrc
|
||||
|
@ -1,23 +1,23 @@
|
||||
FROM tensorflow:ci_build.cpu
|
||||
FROM nvidia/cuda:7.0-cudnn2-devel
|
||||
|
||||
MAINTAINER Jan Prach <jendap@google.com>
|
||||
|
||||
# Install Cuda.
|
||||
RUN cd /tmp && \
|
||||
wget http://developer.download.nvidia.com/compute/cuda/7_0/Prod/local_installers/cuda_7.0.28_linux.run && \
|
||||
chmod +x *.run && ./cuda_*_linux.run -extract=`pwd` && \
|
||||
./NVIDIA-Linux-x86_64-*.run -s --no-kernel-module && \
|
||||
./cuda-linux64-rel-*.run -noprompt && \
|
||||
rm -rf *
|
||||
# Copy ci_build install scripts into the container.
|
||||
COPY install /install
|
||||
|
||||
# Set up CUDA variables in .bashrc
|
||||
RUN echo "CUDA_PATH=/usr/local/cuda" >>~/.bash_profile && \
|
||||
echo "LD_LIBRARY_PATH=/usr/local/cuda/lib64:/tensorflow_extra_deps/cudnn-6.5-linux-x64-v2" >>~/.bash_profile
|
||||
# Run the install scripts.
|
||||
RUN install/install_deb_packages.sh
|
||||
RUN install/install_openjdk8_from_ppa.sh
|
||||
RUN install/install_bazel.sh
|
||||
|
||||
# Set up cuda variables.
|
||||
# Set up BAZELRC environment variable.
|
||||
ENV BAZELRC /root/.bazelrc
|
||||
|
||||
# Set up CUDA variables
|
||||
ENV CUDA_PATH /usr/local/cuda
|
||||
ENV LD_LIBRARY_PATH /usr/local/cuda/lib64:/tensorflow_extra_deps/cudnn-6.5-linux-x64-v2
|
||||
ENV LD_LIBRARY_PATH /usr/local/cuda/lib64
|
||||
|
||||
# Set up variables fo tensorflow to use cuda.
|
||||
# Configure the build for our CUDA configuration.
|
||||
ENV CUDA_TOOLKIT_PATH /usr/local/cuda
|
||||
ENV CUDNN_INSTALL_PATH /tensorflow_extra_deps/cudnn-6.5-linux-x64-v2
|
||||
ENV CUDNN_INSTALL_PATH /usr/local/cuda
|
||||
ENV TF_NEED_CUDA 1
|
||||
|
@ -24,15 +24,11 @@ to docker caching. Individual builds are fast thanks to bazel caching.

## Implementation Details

* The unusual `bazel-user-cache-for-docker` directory is mapped to docker
* The unusual `bazel-ci_build-cache` directory is mapped to docker
  container performing the build using docker's --volume parameter.
  This way we cache bazel output between builds.

* The `$HOME/.tensorflow_extra_deps` directory contains
  [cudnn](https://developer.nvidia.com/cudnn).
  Unfortunatelly this require you to agree a license to download.

* The builds directory hithin this folder contains shell scripts to run within
* The `builds` directory within this folder contains shell scripts to run within
  the container. They essentially contains workarounds for current limitations
  of bazel.

@ -61,21 +57,6 @@ cd tensorflow
tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/...
```

**Note**: For GPU you have to create `$HOME/.tensorflow_extra_deps` and manually
install there required dependencies (i.e. cudnn) for which you have to agree
to licences manually.


#### CUDNN

For GPU download the [cudnn](https://developer.nvidia.com/cudnn).
You will download `cudnn-6.5-linux-x64-v2.tgz`. Run

```bash
mkdir -p $HOME/.tensorflow_extra_deps
tar xzf cudnn-6.5-linux-x64-v2.tgz -C $HOME/.tensorflow_extra_deps
```


## Jobs
@ -86,10 +67,10 @@ The jobs run by [ci.tensorflow.org](http://ci.tensorflow.org) include following:
# Note: You can run the following one-liners yourself if you have Docker.

# build and run cpu tests
tensorflow/tools/ci_build/ci_build.sh CPU bazel test --test_timeout=1800 //tensorflow/...
tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/...

# build gpu
tensorflow/tools/ci_build/ci_build.sh GPU tensorflow/tools/ci_build/builds/gpu.sh
tensorflow/tools/ci_build/ci_build.sh GPU bazel build -c opt --config=cuda //tensorflow/...

# build pip with gpu support
tensorflow/tools/ci_build/ci_build.sh GPU tensorflow/tools/ci_build/builds/gpu_pip.sh
tensorflow/tools/ci_build/builds/configured (new executable file, 34 lines)
@ -0,0 +1,34 @@
|
||||
#!/usr/bin/env bash
|
||||
# Copyright 2015 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# This script is a wrapper to run any build inside the docker cotainer
|
||||
# when running ci_build.sh. It's purpose is to automate the call of ./configure.
|
||||
# Yes, this script is a workaround of a workaround.
|
||||
|
||||
set -e
|
||||
|
||||
CONTAINER_TYPE=$( echo "$1" | tr '[:upper:]' '[:lower:]' )
|
||||
shift 1
|
||||
COMMAND=("$@")
|
||||
|
||||
export PYTHON_BIN_PATH=$(which python)
|
||||
if [ "${CONTAINER_TYPE}" == "gpu" ]; then
|
||||
export TF_NEED_CUDA=1
|
||||
fi
|
||||
|
||||
./configure
|
||||
|
||||
${COMMAND[@]}
|
@ -16,8 +16,6 @@
|
||||
|
||||
set -e
|
||||
|
||||
export TF_NEED_CUDA=1
|
||||
./configure
|
||||
bazel build -c opt --config=cuda //tensorflow/tools/pip_package:build_pip_package
|
||||
rm -rf /root/.cache/tensorflow-pip
|
||||
bazel-bin/tensorflow/tools/pip_package/build_pip_package /root/.cache/tensorflow-pip
|
||||
|
@ -47,10 +47,6 @@ function upsearch () {
|
||||
WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}"
|
||||
BUILD_TAG="${BUILD_TAG:-tf_ci}"
|
||||
|
||||
# Additional configuration. You can customize it by modifying
|
||||
# env variable.
|
||||
EXTRA_DEPS_DIR="${EXTRA_DEPS_DIR:-${HOME}/.tensorflow_extra_deps}"
|
||||
|
||||
|
||||
# Print arguments.
|
||||
echo "CONTAINER_TYPE: ${CONTAINER_TYPE}"
|
||||
@ -58,31 +54,21 @@ echo "COMMAND: ${COMMAND[@]}"
|
||||
echo "WORKSAPCE: ${WORKSPACE}"
|
||||
echo "BUILD_TAG: ${BUILD_TAG}"
|
||||
echo " (docker container name will be ${BUILD_TAG}.${CONTAINER_TYPE})"
|
||||
echo "EXTRA_DEPS_DIR: ${EXTRA_DEPS_DIR}"
|
||||
echo ""
|
||||
|
||||
# Build the docker containers.
|
||||
echo "Building CPU container (${BUILD_TAG}.cpu)..."
|
||||
docker build -t ${BUILD_TAG}.cpu -f ${SCRIPT_DIR}/Dockerfile.cpu ${SCRIPT_DIR}
|
||||
if [ "${CONTAINER_TYPE}" != "cpu" ]; then
|
||||
echo "Building container ${BUILD_TAG}.${CONTAINER_TYPE}..."
|
||||
tmp_dockerfile="${SCRIPT_DIR}/Dockerfile.${CONTAINER_TYPE}.${BUILD_TAG}"
|
||||
# we need to generate temporary dockerfile with overwritten FROM directive
|
||||
sed "s/^FROM .*/FROM ${BUILD_TAG}.cpu/" \
|
||||
${SCRIPT_DIR}/Dockerfile.${CONTAINER_TYPE} > ${tmp_dockerfile}
|
||||
docker build -t ${BUILD_TAG}.${CONTAINER_TYPE} \
|
||||
-f ${tmp_dockerfile} ${SCRIPT_DIR}
|
||||
rm ${tmp_dockerfile}
|
||||
fi
|
||||
|
||||
# Build the docker container.
|
||||
echo "Building container (${BUILD_TAG}.${CONTAINER_TYPE})..."
|
||||
docker build -t ${BUILD_TAG}.${CONTAINER_TYPE} \
|
||||
-f ${SCRIPT_DIR}/Dockerfile.${CONTAINER_TYPE} ${SCRIPT_DIR}
|
||||
|
||||
|
||||
# Run the command inside the container.
|
||||
echo "Running '${COMMAND[@]}' inside ${BUILD_TAG}.${CONTAINER_TYPE}..."
|
||||
mkdir -p ${WORKSPACE}/bazel-user-cache-for-docker
|
||||
mkdir -p ${WORKSPACE}/bazel-ci_build-cache
|
||||
docker run \
|
||||
-v ${WORKSPACE}/bazel-user-cache-for-docker:/root/.cache \
|
||||
-v ${WORKSPACE}/bazel-ci_build-cache:/root/.cache \
|
||||
-v ${WORKSPACE}:/tensorflow \
|
||||
-v ${EXTRA_DEPS_DIR}:/tensorflow_extra_deps \
|
||||
-w /tensorflow \
|
||||
${BUILD_TAG}.${CONTAINER_TYPE} \
|
||||
"${COMMAND[@]}"
|
||||
"tensorflow/tools/ci_build/builds/configured" "${CONTAINER_TYPE}" "${COMMAND[@]}"
|
||||
|
tensorflow/tools/ci_build/install/install_bazel.sh (new executable file, 49 lines)
@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env bash
|
||||
# Copyright 2015 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
set -e
|
||||
|
||||
# Select bazel version.
|
||||
BAZEL_VERSION="0.1.1"
|
||||
|
||||
# Install bazel.
|
||||
mkdir /bazel
|
||||
cd /bazel
|
||||
curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh
|
||||
curl -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE.txt
|
||||
chmod +x /bazel/bazel-*.sh
|
||||
/bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh
|
||||
rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh
|
||||
|
||||
# Enable bazel auto completion.
|
||||
echo "source /usr/local/lib/bazel/bin/bazel-complete.bash" >> ~/.bashrc
|
||||
|
||||
# Running bazel inside a `docker build` command causes trouble, cf:
|
||||
# https://github.com/bazelbuild/bazel/issues/134
|
||||
# The easiest solution is to set up a bazelrc file forcing --batch.
|
||||
echo "startup --batch" >>/root/.bazelrc
|
||||
# Similarly, we need to workaround sandboxing issues:
|
||||
# https://github.com/bazelbuild/bazel/issues/418
|
||||
echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
|
||||
>>/root/.bazelrc
|
||||
# Force bazel output to use colors (good for jenkins).
|
||||
echo "common --color=yes" >>/root/.bazelrc
|
||||
# Configure tests - increase timeout, print errors and timeout warnings
|
||||
echo "test" \
|
||||
" --test_timeout=3600" \
|
||||
" --test_output=errors" \
|
||||
" --test_verbose_timeout_warnings" \
|
||||
>>/root/.bazelrc
|
tensorflow/tools/ci_build/install/install_deb_packages.sh (new executable file, 36 lines)
@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env bash
|
||||
# Copyright 2015 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
set -e
|
||||
|
||||
# Install dependencies from ubuntu deb repository.
|
||||
apt-get update
|
||||
apt-get install -y \
|
||||
build-essential \
|
||||
curl \
|
||||
git \
|
||||
pkg-config \
|
||||
python-dev \
|
||||
python-numpy \
|
||||
python-pip \
|
||||
software-properties-common \
|
||||
swig \
|
||||
unzip \
|
||||
wget \
|
||||
zip \
|
||||
zlib1g-dev
|
||||
apt-get clean
|
||||
rm -rf /var/lib/apt/lists/*
|
@ -16,6 +16,9 @@
|
||||
|
||||
set -e
|
||||
|
||||
export TF_NEED_CUDA=1
|
||||
./configure
|
||||
bazel build -c opt --config=cuda //tensorflow/...
|
||||
# Install openjdk 8 for bazel from PPA (it is not available in 14.04).
|
||||
add-apt-repository -y ppa:openjdk-r/ppa
|
||||
apt-get update
|
||||
apt-get install -y openjdk-8-jdk openjdk-8-jre-headless
|
||||
apt-get clean
|
||||
rm -rf /var/lib/apt/lists/*
|