Merge changes from github.

Change: 123342870
This commit is contained in:
Vijay Vasudevan 2016-05-26 11:05:13 -08:00 committed by TensorFlower Gardener
parent 9a69f398e9
commit 8cc567bf97
33 changed files with 282 additions and 323 deletions

View File

@ -1,3 +1,5 @@
workspace(name = "org_tensorflow")
# Uncomment and update the paths in these entries to build the Android demo.
#android_sdk_repository(
# name = "androidsdk",

View File

@ -36,6 +36,10 @@ cc_binary(
cc_test(
name = "convert_graphdef_memmapped_format_test",
srcs = ["convert_graphdef_memmapped_format_test.cc"],
linkopts = select({
"//tensorflow:darwin": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
deps = [
":convert_graphdef_memmapped_format_lib",
"//tensorflow/cc:cc_ops",

View File

@ -1231,6 +1231,10 @@ cc_test(
# higher level tests
tf_cc_tests(
size = "small",
linkopts = select({
"//tensorflow:darwin": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
linkstatic = tf_kernel_tests_linkstatic(),
tests = glob(
[

View File

@ -324,6 +324,10 @@ cc_library(
tf_cuda_cc_tests(
size = "small",
linkopts = select({
"//tensorflow:darwin": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + [],
tests = [

View File

@ -473,6 +473,10 @@ tf_cc_test(
tf_cc_test(
name = "slice_op_test",
size = "small",
linkopts = select({
"//tensorflow:darwin": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
deps = [
":ops_testutil",
":ops_util",
@ -768,6 +772,10 @@ tf_cc_tests(
)
tf_cc_tests(
linkopts = select({
"//tensorflow:darwin": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
tests = [
"adjust_contrast_op_test",
"colorspace_op_test",
@ -1058,6 +1066,10 @@ tf_cuda_cc_test(
tf_cuda_cc_test(
name = "reduction_ops_test",
size = "small",
linkopts = select({
"//tensorflow:darwin": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
deps = [
":ops_testutil",
":ops_util",

View File

@ -210,7 +210,7 @@ REGISTER_CPU(complex64);
REGISTER_CPU(complex128);
#if GOOGLE_CUDA
REGISTER_GPU(float);
// REGISTER_GPU(double);
REGISTER_GPU(double);
#if CUDA_VERSION >= 7050
REGISTER_GPU(Eigen::half);
#endif

View File

@ -130,6 +130,10 @@ Status DoTranspose<Device>(const Device& d, const Tensor& in,
internal::Transpose<Device, uint64>(d, in, perm, out);
break;
case DT_COMPLEX128:
internal::Transpose<Device, float4>(d, in, perm, out);
break;
default:
return errors::Unimplemented("Unsupported dtype on GPU: ", in.dtype());
}

View File

@ -6366,9 +6366,7 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -6391,9 +6389,7 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -6416,8 +6412,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
@ -6993,9 +6987,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -7018,9 +7009,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -7043,10 +7031,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
}
}
@ -7762,9 +7746,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -7787,9 +7768,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -7812,10 +7790,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
}
}
@ -7837,9 +7811,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -7862,9 +7833,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -7887,10 +7855,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
}
}
@ -7927,9 +7891,7 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -7952,9 +7914,7 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -7977,8 +7937,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
@ -10263,9 +10221,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -10288,9 +10243,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -10313,10 +10265,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
}
}
@ -10390,9 +10338,7 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -10415,9 +10361,7 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -10440,8 +10384,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
@ -16137,9 +16079,7 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -16162,9 +16102,7 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -16187,8 +16125,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
@ -17916,9 +17852,7 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -17941,9 +17875,7 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -17966,8 +17898,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
@ -18089,9 +18019,7 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -18114,9 +18042,7 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -18139,8 +18065,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
@ -20897,9 +20821,7 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -20922,9 +20844,7 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -20947,8 +20867,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
@ -21696,9 +21614,7 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -21721,9 +21637,7 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_COMPLEX64
type: DT_INT64
}
}
}
@ -21746,8 +21660,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}

View File

@ -136,6 +136,14 @@ tf.complex_abs(x) ==> [5.25594902, 6.60492229]
Input("x: T").Output("y: T").Attr( \
"T: {half, float, double, int32, int64, complex64, complex128}")
#define UNARY_REAL() \
Input("x: T").Output("y: T").Attr( \
"T: {half, float, double}")
#define UNARY_COMPLEX() \
Input("x: T").Output("y: T").Attr( \
"T: {half, float, double, complex64, complex128}")
REGISTER_OP("Neg")
.UNARY()
.Doc(R"doc(
@ -158,65 +166,65 @@ I.e., \\(y = x * x = x^2\\).
)doc");
REGISTER_OP("Sqrt")
.UNARY()
.UNARY_COMPLEX()
.Doc(R"doc(
Computes square root of x element-wise.
I.e., \\(y = \sqrt{x} = x^{1/2}\\).
)doc");
REGISTER_OP("Rsqrt")
.UNARY()
.UNARY_COMPLEX()
.Doc(R"doc(
Computes reciprocal of square root of x element-wise.
I.e., \\(y = 1 / \sqrt{x}\\).
)doc");
REGISTER_OP("Exp")
.UNARY()
.UNARY_COMPLEX()
.Doc(R"doc(
Computes exponential of x element-wise. \\(y = e^x\\).
)doc");
REGISTER_OP("Log")
.UNARY()
.UNARY_COMPLEX()
.Doc(R"doc(
Computes natural logarithm of x element-wise.
I.e., \\(y = \log_e x\\).
)doc");
REGISTER_OP("Tanh")
.UNARY()
.UNARY_COMPLEX()
.Doc(R"doc(
Computes hyperbolic tangent of `x` element-wise.
)doc");
REGISTER_OP("Lgamma")
.UNARY()
.UNARY_REAL()
.Doc(R"doc(
Computes the log of the absolute value of `Gamma(x)` element-wise.
)doc");
REGISTER_OP("Digamma")
.UNARY()
.UNARY_REAL()
.Doc(R"doc(
Computes Psi, the derivative of Lgamma (the log of the absolute value of
`Gamma(x)`), element-wise.
)doc");
REGISTER_OP("Erf")
.UNARY()
.UNARY_REAL()
.Doc(R"doc(
Computes the Gauss error function of `x` element-wise.
)doc");
REGISTER_OP("Erfc")
.UNARY()
.UNARY_REAL()
.Doc(R"doc(
Computes the complementary error function of `x` element-wise.
)doc");
REGISTER_OP("Sigmoid")
.UNARY()
.UNARY_COMPLEX()
.Doc(R"doc(
Computes sigmoid of `x` element-wise.
@ -224,18 +232,20 @@ Specifically, `y = 1 / (1 + exp(-x))`.
)doc");
REGISTER_OP("Sin")
.UNARY()
.UNARY_COMPLEX()
.Doc(R"doc(
Computes sin of x element-wise.
)doc");
REGISTER_OP("Cos")
.UNARY()
.UNARY_COMPLEX()
.Doc(R"doc(
Computes cos of x element-wise.
)doc");
#undef UNARY
#undef UNARY_REAL
#undef UNARY_COMPLEX
REGISTER_OP("IsNan")
.Input("x: T")

View File

@ -521,13 +521,13 @@ def get_random_distorted_bottlenecks(
ground_truths = []
for unused_i in range(how_many):
label_index = random.randrange(class_count)
label_name = image_lists.keys()[label_index]
label_name = list(image_lists.keys())[label_index]
image_index = random.randrange(65536)
image_path = get_image_path(image_lists, label_name, image_index, image_dir,
category)
if not gfile.Exists(image_path):
tf.logging.fatal('File does not exist %s', image_path)
jpeg_data = gfile.FastGFile(image_path, 'r').read()
jpeg_data = gfile.FastGFile(image_path, 'rb').read()
# Note that we materialize the distorted_image_data as a numpy array before
# sending running inference on the image. This involves 2 memory copies and
# might be optimized in other implementations.
@ -616,7 +616,7 @@ def add_input_distortions(flip_left_right, random_crop, random_scale,
"""
jpeg_data = tf.placeholder(tf.string, name='DistortJPGInput')
decoded_image = tf.image.decode_jpeg(jpeg_data)
decoded_image = tf.image.decode_jpeg(jpeg_data, channels=MODEL_INPUT_DEPTH)
decoded_image_as_float = tf.cast(decoded_image, dtype=tf.float32)
decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
margin_scale = 1.0 + (random_crop / 100.0)

View File

@ -60,7 +60,7 @@ tf.zeros_like(tensor) ==> [[0, 0, 0], [0, 0, 0]]
* <b>`tensor`</b>: A `Tensor`.
* <b>`dtype`</b>: A type for the returned `Tensor`. Must be `float32`, `float64`,
`int8`, `int16`, `int32`, `int64`, `uint8`, or `complex64`.
`int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, or `complex128`.
* <b>`name`</b>: A name for the operation (optional).
@ -119,7 +119,7 @@ tf.ones_like(tensor) ==> [[1, 1, 1], [1, 1, 1]]
* <b>`tensor`</b>: A `Tensor`.
* <b>`dtype`</b>: A type for the returned `Tensor`. Must be `float32`, `float64`,
`int8`, `int16`, `int32`, `int64`, `uint8`, or `complex64`.
`int8`, `int16`, `int32`, `int64`, `uint8`, `complex64` or `complex128`.
* <b>`name`</b>: A name for the operation (optional).

View File

@ -163,7 +163,7 @@ case where both types are quantized.
* <b>`value`</b>: A `Tensor` with type `float`, `double`, `int64`, `int32`, `uint8`,
`int16`, `int8`, or `complex64`.
`int16`, `int8`, `complex64` or `complex128`.
* <b>`bias`</b>: A 1-D `Tensor` with size matching the last dimension of `value`.
Must be the same type as `value` unless `value` is a quantized type,
in which case a different quantized type may be used.
@ -186,7 +186,7 @@ Specifically, `y = 1 / (1 + exp(-x))`.
##### Args:
* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `complex128`, `int64`,
or `qint32`.
* <b>`name`</b>: A name for the operation (optional).
@ -205,7 +205,7 @@ Computes hyperbolic tangent of `x` element-wise.
##### Args:
* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `complex128`, `int64`,
or `qint32`.
* <b>`name`</b>: A name for the operation (optional).

View File

@ -658,7 +658,7 @@ bazel build -c opt --config=cuda //tensorflow/tools/pip_package:build_pip_packag
mkdir _python_build
cd _python_build
ln -s ../bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/* .
ln -s ../bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow* .
ln -s ../tensorflow/tools/pip_package/* .
python setup.py develop
```

View File

@ -19,7 +19,7 @@ import numpy as np
#Imports for visualization
import PIL.Image
from cStringIO import StringIO
from io import BytesIO
from IPython.display import clear_output, Image, display
```
@ -30,8 +30,9 @@ def DisplayArray(a, fmt='jpeg', rng=[0,1]):
"""Display an array as a picture."""
a = (a - rng[0])/float(rng[1] - rng[0])*255
a = np.uint8(np.clip(a, 0, 255))
f = StringIO()
f = BytesIO()
PIL.Image.fromarray(a).save(f, fmt)
clear_output(wait = True)
display(Image(data=f.getvalue()))
```
@ -132,10 +133,7 @@ tf.initialize_all_variables().run()
for i in range(1000):
# Step simulation
step.run({eps: 0.03, damping: 0.04})
# Visualize every 50 steps
if i % 50 == 0:
clear_output()
DisplayArray(U.eval(), rng=[-0.1, 0.1])
DisplayArray(U.eval(), rng=[-0.1, 0.1])
```
![jpeg](../../images/pde_output_2.jpg)

View File

@ -83,17 +83,15 @@ class Seq2SeqModel(object):
softmax_loss_function = None
# Sampled softmax only makes sense if we sample less than vocabulary size.
if num_samples > 0 and num_samples < self.target_vocab_size:
with tf.device("/cpu:0"):
w = tf.get_variable("proj_w", [size, self.target_vocab_size])
w_t = tf.transpose(w)
b = tf.get_variable("proj_b", [self.target_vocab_size])
w = tf.get_variable("proj_w", [size, self.target_vocab_size])
w_t = tf.transpose(w)
b = tf.get_variable("proj_b", [self.target_vocab_size])
output_projection = (w, b)
def sampled_loss(inputs, labels):
with tf.device("/cpu:0"):
labels = tf.reshape(labels, [-1, 1])
return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples,
self.target_vocab_size)
labels = tf.reshape(labels, [-1, 1])
return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples,
self.target_vocab_size)
softmax_loss_function = sampled_loss
# Create the internal multi-layer cell for our RNN.

View File

@ -24,6 +24,10 @@ from tensorflow.python.framework import ops
_DEFAULT_GRAPH_SEED = 87654321
_MAXINT32 = 2**31 - 1
def _truncate_seed(seed):
return seed % _MAXINT32 # truncate to fit into 32-bit integer
def get_seed(op_seed):
@ -47,12 +51,12 @@ def get_seed(op_seed):
graph_seed = ops.get_default_graph().seed
if graph_seed is not None:
if op_seed is not None:
return graph_seed, op_seed
return _truncate_seed(graph_seed), _truncate_seed(op_seed)
else:
return graph_seed, ops.get_default_graph()._last_id
return _truncate_seed(graph_seed), _truncate_seed(ops.get_default_graph()._last_id)
else:
if op_seed is not None:
return _DEFAULT_GRAPH_SEED, op_seed
return _truncate_seed(_DEFAULT_GRAPH_SEED), _truncate_seed(op_seed)
else:
return None, None

View File

@ -95,6 +95,7 @@ class MatMulTest(tf.test.TestCase):
x = np.arange(1., 5.).reshape([4, 1]).astype(np.float64)
y = np.arange(1., 3.).reshape([1, 2]).astype(np.float64)
self._testCpuMatmul(x, y)
self._testGpuMatmul(x, y)
def testHalfBasic(self):
x = np.arange(1., 5.).reshape([4, 1]).astype(np.float16)
@ -135,6 +136,7 @@ class MatMulTest(tf.test.TestCase):
x = self._randMatrix(n, k, np.float64)
y = self._randMatrix(k, m, np.float64)
self._testCpuMatmul(x, y)
self._testGpuMatmul(x, y)
def testHalfRandom(self):
for _ in range(10):
@ -185,6 +187,7 @@ class MatMulTest(tf.test.TestCase):
x = self._randMatrix(k, n, np.float64)
y = self._randMatrix(m, k, np.float64)
self._testCpuMatmul(x, y, True, True)
self._testGpuMatmul(x, y, True, True)
def testHalfRandomTransposeBoth(self):
for _ in range(10):

View File

@ -237,9 +237,10 @@ class RandomUniformTest(tf.test.TestCase):
def testSeed(self):
for use_gpu in False, True:
for dt in tf.float16, tf.float32, tf.float64, tf.int32, tf.int64:
sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=345)
sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=345)
self.assertAllEqual(sx(), sy())
for seed in [345, 2**100, -2**100]:
sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=seed)
sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=seed)
self.assertAllEqual(sx(), sy())
def testNoCSE(self):
shape = [2, 3, 4]

View File

@ -131,7 +131,7 @@ class TransposeTest(tf.test.TestCase):
self._compare_cpu_gpu(
np.arange(0, 16).reshape([1, 2, 1, 2, 1, 2, 1, 2]).astype(np.float64))
def testSComplex(self):
def testComplex64(self):
self._testBoth(np.complex(1, 2) *
np.arange(0, 21).reshape([3, 7]).astype(np.complex64))
self._testBoth(np.complex(1, 2) *
@ -140,6 +140,15 @@ class TransposeTest(tf.test.TestCase):
np.complex(1, 2) *
np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(np.complex64))
def testComplex128(self):
self._testBoth(np.complex(1, 2) *
np.arange(0, 21).reshape([3, 7]).astype(np.complex128))
self._testBoth(np.complex(1, 2) *
np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.complex128))
self._testBoth(
np.complex(1, 2) *
np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(np.complex128))
def testInt8(self):
self._testBoth(np.arange(0, 21).reshape([3, 7]).astype(np.int8))
self._testBoth(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int8))

View File

@ -260,9 +260,8 @@ def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols,
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
with variable_scope.variable_scope(scope or "embedding_rnn_decoder"):
with ops.device("/cpu:0"):
embedding = variable_scope.get_variable("embedding",
[num_symbols, embedding_size])
embedding = variable_scope.get_variable("embedding",
[num_symbols, embedding_size])
loop_function = _extract_argmax_and_embed(
embedding, output_projection,
update_embedding_for_previous) if feed_previous else None
@ -398,9 +397,8 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
with variable_scope.variable_scope(scope or "embedding_tied_rnn_seq2seq"):
with ops.device("/cpu:0"):
embedding = variable_scope.get_variable("embedding",
[num_symbols, embedding_size])
embedding = variable_scope.get_variable("embedding",
[num_symbols, embedding_size])
emb_encoder_inputs = [embedding_ops.embedding_lookup(embedding, x)
for x in encoder_inputs]
@ -636,9 +634,8 @@ def embedding_attention_decoder(decoder_inputs, initial_state, attention_states,
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
with variable_scope.variable_scope(scope or "embedding_attention_decoder"):
with ops.device("/cpu:0"):
embedding = variable_scope.get_variable("embedding",
[num_symbols, embedding_size])
embedding = variable_scope.get_variable("embedding",
[num_symbols, embedding_size])
loop_function = _extract_argmax_and_embed(
embedding, output_projection,
update_embedding_for_previous) if feed_previous else None

View File

@ -25,7 +25,7 @@ import tensorflow as tf
class AdagradOptimizerTest(tf.test.TestCase):
def doTestBasic(self, use_locking=False):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@ -56,7 +56,7 @@ class AdagradOptimizerTest(tf.test.TestCase):
self.doTestBasic(use_locking=True)
def testTensorLearningRate(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@ -80,43 +80,8 @@ class AdagradOptimizerTest(tf.test.TestCase):
self.assertAllCloseAccordingToType(
np.array([2.715679168701172, 3.715679168701172]), var1.eval())
def testFloat64(self):
with self.test_session():
opt = tf.train.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
# compute_gradients.
values = [1.0, 3.0]
good_vars = [tf.Variable([v]) for v in values]
bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
self.assertRaisesRegexp(
ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
opt.compute_gradients, bad_loss, good_vars)
bad_vars = [
tf.Variable(np.array([v], np.float64), name="bad_var")
for v in values
]
self.assertRaisesRegexp(
ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
bad_vars)
opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
# apply_gradients.
bad_grads = [
tf.constant([0.1], dtype=np.float64, name="bad_grad"),
tf.constant([0.01])
]
self.assertRaisesRegexp(
ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
opt.apply_gradients, zip(bad_grads, good_vars))
good_grads = [tf.constant([0.01]), tf.constant([0.02])]
self.assertRaisesRegexp(
ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
opt.apply_gradients, zip(good_grads, bad_vars))
opt.apply_gradients(zip(good_grads, good_vars))
def testSparseBasic(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([[1.0], [2.0]], dtype=dtype)
var1 = tf.Variable([[3.0], [4.0]], dtype=dtype)
@ -145,7 +110,7 @@ class AdagradOptimizerTest(tf.test.TestCase):
np.array([[3.0], [3.715679168701172]]), var1.eval())
def testSparseStability(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
shape = [1, 6]
var0 = tf.Variable(
@ -175,7 +140,7 @@ class AdagradOptimizerTest(tf.test.TestCase):
0.0144573, -0.01029443]]), var0.eval())
def testSharing(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)

View File

@ -36,7 +36,7 @@ def adam_update_numpy(param, g_t, t, m, v, alpha=0.001, beta1=0.9, beta2=0.999,
class AdamOptimizerTest(tf.test.TestCase):
def testSparse(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
@ -79,7 +79,7 @@ class AdamOptimizerTest(tf.test.TestCase):
self.assertAllCloseAccordingToType(var1_np, var1.eval())
def testBasic(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
@ -116,7 +116,7 @@ class AdamOptimizerTest(tf.test.TestCase):
self.assertAllCloseAccordingToType(var1_np, var1.eval())
def testTensorLearningRate(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
@ -152,41 +152,8 @@ class AdamOptimizerTest(tf.test.TestCase):
self.assertAllCloseAccordingToType(var0_np, var0.eval())
self.assertAllCloseAccordingToType(var1_np, var1.eval())
def testFloat64(self):
with self.test_session():
opt = tf.train.AdamOptimizer()
# compute_gradients.
values = [1.0, 3.0]
good_vars = [tf.Variable([v]) for v in values]
bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
self.assertRaisesRegexp(
ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
opt.compute_gradients, bad_loss, good_vars)
bad_vars = [
tf.Variable(np.array([v], np.float64), name="bad_var")
for v in values]
self.assertRaisesRegexp(
ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
bad_vars)
opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
# apply_gradients.
bad_grads = [
tf.constant([0.1], dtype=np.float64, name="bad_grad"),
tf.constant([0.01])]
self.assertRaisesRegexp(
ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
opt.apply_gradients, zip(bad_grads, good_vars))
good_grads = [tf.constant([0.01]), tf.constant([0.02])]
self.assertRaisesRegexp(
ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
opt.apply_gradients, zip(good_grads, bad_vars))
opt.apply_gradients(zip(good_grads, good_vars))
def testSharing(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0

View File

@ -25,7 +25,7 @@ import tensorflow as tf
class GradientDescentOptimizerTest(tf.test.TestCase):
def testBasic(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@ -46,7 +46,7 @@ class GradientDescentOptimizerTest(tf.test.TestCase):
[3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], var1.eval())
def testTensorLearningRate(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@ -67,43 +67,8 @@ class GradientDescentOptimizerTest(tf.test.TestCase):
self.assertAllCloseAccordingToType(
[3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], var1.eval())
def testFloat64(self):
with self.test_session():
opt = tf.train.GradientDescentOptimizer(3.0)
# compute_gradients.
values = [1.0, 3.0]
good_vars = [tf.Variable([v]) for v in values]
bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
self.assertRaisesRegexp(
ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
opt.compute_gradients, bad_loss, good_vars)
bad_vars = [
tf.Variable(np.array([v], np.float64), name="bad_var")
for v in values
]
self.assertRaisesRegexp(
ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
bad_vars)
opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
# apply_gradients.
bad_grads = [
tf.constant([0.1], dtype=np.float64, name="bad_grad"),
tf.constant([0.01])
]
self.assertRaisesRegexp(
ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
opt.apply_gradients, zip(bad_grads, good_vars))
good_grads = [tf.constant([0.01]), tf.constant([0.02])]
self.assertRaisesRegexp(
ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
opt.apply_gradients, zip(good_grads, bad_vars))
opt.apply_gradients(zip(good_grads, good_vars))
def testGradWrtRef(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
opt = tf.train.GradientDescentOptimizer(3.0)
values = [1.0, 3.0]
@ -114,7 +79,7 @@ class GradientDescentOptimizerTest(tf.test.TestCase):
self.assertAllCloseAccordingToType([1.0], grad.eval())
def testWithGlobalStep(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
global_step = tf.Variable(0, trainable=False)
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
@ -138,7 +103,7 @@ class GradientDescentOptimizerTest(tf.test.TestCase):
self.assertAllCloseAccordingToType(1, global_step.eval())
def testSparseBasic(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([[1.0], [2.0]], dtype=dtype)
var1 = tf.Variable([[3.0], [4.0]], dtype=dtype)

View File

@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import control_flow_ops
def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
@ -84,3 +85,62 @@ def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
if staircase:
p = math_ops.floor(p)
return math_ops.mul(learning_rate, math_ops.pow(decay_rate, p), name=name)
def piecewise_constant(x, boundaries, values, name=None):
""" Piecewise constant from boundaries and interval values.
Example: use a learning rate that's 1.0 for the first 100000 steps, 0.5
for steps 100001 to 110000, and 0.1 for any additional steps.
```python
global_step = tf.Variable(0, trainable=False)
boundaries = [100000, 110000]
values = [1.0, 0.5, 0.1]
learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
# Later, whenever we perform an optimization step, we increment global_step.
```
Args:
x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`,
`float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
increasing entries, and with all elements having the same type as `x`.
values: A list of `Tensor`s or float`s or `int`s that specifies the values
for the intervals defined by `boundaries`. It should have one more element
than `boundaries`, and all elements should have the same type.
name: A string. Optional name of the operation. Defaults to
'PiecewiseConstant'.
Returns:
A 0-D Tensor. Its value is `values[0]` when `x <= boundaries[0]`,
`values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ...,
and values[-1] when `x > boundaries[-1]`.
"""
with ops.op_scope([x, boundaries, values, name],
name, 'PiecewiseConstant') as name:
x = ops.convert_to_tensor(x)
# Avoid explicit conversion to x's dtype. This could result in faulty
# comparisons, for example if floats are converted to integers.
boundaries = ops.convert_n_to_tensor(boundaries)
if not all(b.dtype == x.dtype for b in boundaries):
raise ValueError('boundaries must have the same dtype as x.')
# TODO(rdipietro): Ensure that boundaries' elements are strictly increasing.
values = ops.convert_n_to_tensor(values)
if not all(v.dtype == values[0].dtype for v in values):
raise ValueError('values must have elements all with the same dtype.')
pred_fn_pairs = {}
pred_fn_pairs[x <= boundaries[0]] = lambda: values[0]
pred_fn_pairs[x > boundaries[-1]] = lambda: values[-1]
for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
# Need to bind v here; can do this with lambda v=v: ...
pred = (x > low) & (x <= high)
pred_fn_pairs[pred] = lambda v=v: v
# The default isn't needed here because our conditions are mutually
# exclusive and exhaustive, but tf.case requires it.
default = lambda: values[0]
return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)

View File

@ -72,6 +72,41 @@ class LRDecayTest(test_util.TensorFlowTestCase):
expected = .1 * 0.96**(100 // 3)
self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
def testPiecewiseConstant(self):
with self.test_session():
x = variables.Variable(-999)
assign_100 = x.assign(100)
assign_105 = x.assign(105)
assign_110 = x.assign(110)
assign_120 = x.assign(120)
assign_999 = x.assign(999)
pc = learning_rate_decay.piecewise_constant(x, [100, 110, 120],
[1.0, 0.1, 0.01, 0.001])
variables.initialize_all_variables().run()
self.assertAllClose(pc.eval(), 1.0, 1e-6)
assign_100.op.run()
self.assertAllClose(pc.eval(), 1.0, 1e-6)
assign_105.op.run()
self.assertAllClose(pc.eval(), 0.1, 1e-6)
assign_110.op.run()
self.assertAllClose(pc.eval(), 0.1, 1e-6)
assign_120.op.run()
self.assertAllClose(pc.eval(), 0.01, 1e-6)
assign_999.op.run()
self.assertAllClose(pc.eval(), 0.001, 1e-6)
def testPiecewiseConstantEdgeCases(self):
with self.test_session():
with self.assertRaises(ValueError):
x_int = variables.Variable(0, dtype=variables.dtypes.int32)
boundaries, values = [-1.0, 1.0], [1, 2, 3]
pc = learning_rate_decay.piecewise_constant(x_int, boundaries, values)
with self.assertRaises(ValueError):
x = variables.Variable(0.0)
boundaries, values = [-1.0, 1.0], [1.0, 2, 3]
pc = learning_rate_decay.piecewise_constant(x, boundaries, values)
if __name__ == "__main__":
googletest.main()

View File

@ -26,7 +26,7 @@ import tensorflow as tf
class MomentumOptimizerTest(tf.test.TestCase):
def testBasic(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@ -81,7 +81,7 @@ class MomentumOptimizerTest(tf.test.TestCase):
var1.eval())
def testTensorLearningRateAndMomentum(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@ -136,39 +136,6 @@ class MomentumOptimizerTest(tf.test.TestCase):
3.98 - ((0.9 * 0.01 + 0.01) * 2.0)]),
var1.eval())
def testFloat64(self):
with self.test_session():
opt = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9)
# compute_gradients.
values = [1.0, 3.0]
good_vars = [tf.Variable([v]) for v in values]
bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
self.assertRaisesRegexp(
ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
opt.compute_gradients, bad_loss, good_vars)
bad_vars = [
tf.Variable(np.array([v], np.float64), name="bad_var")
for v in values]
self.assertRaisesRegexp(
ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
bad_vars)
opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
# apply_gradients.
bad_grads = [
tf.constant([0.1], dtype=np.float64, name="bad_grad"),
tf.constant([0.01])]
self.assertRaisesRegexp(
ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
opt.apply_gradients, zip(bad_grads, good_vars))
good_grads = [tf.constant([0.01]), tf.constant([0.02])]
self.assertRaisesRegexp(
ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
opt.apply_gradients, zip(good_grads, bad_vars))
opt.apply_gradients(zip(good_grads, good_vars))
def _dbParamsMom01(self):
"""Return dist-belief momentum values.
@ -222,7 +189,7 @@ class MomentumOptimizerTest(tf.test.TestCase):
self.assertAllClose(np.array(db_out[i]), var0.eval())
def testSparse(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable(tf.zeros([4, 2], dtype=dtype))
var1 = tf.Variable(tf.constant(1.0, dtype, [4, 2]))
@ -290,7 +257,7 @@ class MomentumOptimizerTest(tf.test.TestCase):
var1.eval()[2])
def testSharing(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)

View File

@ -376,7 +376,7 @@ class Optimizer(object):
Returns:
Valid types for loss, variables and gradients.
"""
return set([dtypes.float16, dtypes.float32])
return set([dtypes.float16, dtypes.float32, dtypes.float64])
def _create_slots(self, var_list):
"""Create all slots needed by the variables.

View File

@ -23,7 +23,7 @@ import tensorflow as tf
class OptimizerTest(tf.test.TestCase):
def testBasic(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@ -43,7 +43,7 @@ class OptimizerTest(tf.test.TestCase):
self.assertAllClose([-6., -5.], var1.eval())
def testAggregationMethod(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@ -67,7 +67,7 @@ class OptimizerTest(tf.test.TestCase):
self.assertAllClose([-6., -5.], var1.eval())
def testPrecomputedGradient(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@ -92,7 +92,7 @@ class OptimizerTest(tf.test.TestCase):
[3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)], var1.eval())
def testNoVariables(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype, trainable=False)
var1 = tf.Variable([3.0, 4.0], dtype=dtype, trainable=False)
@ -102,7 +102,7 @@ class OptimizerTest(tf.test.TestCase):
sgd_op.minimize(cost)
def testNoGradients(self):
for dtype in [tf.half, tf.float32]:
for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)

View File

@ -138,9 +138,9 @@ string GetCudnnVersion() { return ""; }
static std::vector<string>* CreatePrimordialRpaths() {
auto rpaths = new std::vector<string>;
#if defined(__APPLE__)
rpaths->push_back("driver/driver_sh.runfiles/third_party/gpus/cuda/lib");
rpaths->push_back("driver/driver_sh.runfiles/org_tensorflow/third_party/gpus/cuda/lib");
#else
rpaths->push_back("driver/driver_sh.runfiles/third_party/gpus/cuda/lib64");
rpaths->push_back("driver/driver_sh.runfiles/org_tensorflow/third_party/gpus/cuda/lib64");
#endif
return rpaths;
}

View File

@ -234,7 +234,7 @@ def tf_gen_op_wrapper_py(name, out=None, hidden=[], visibility=None, deps=[],
# TODO(opensource): we need to enable this to work around the hidden symbol
# __cudaRegisterFatBinary error. Need more investigations.
def tf_cc_test(name, deps, linkstatic=0, tags=[], data=[], size="medium",
suffix="", args=None):
suffix="", args=None, linkopts=[]):
name = name.replace(".cc", "")
native.cc_test(name="%s%s" % (name.replace("/", "_"), suffix),
size=size,
@ -243,7 +243,7 @@ def tf_cc_test(name, deps, linkstatic=0, tags=[], data=[], size="medium",
copts=tf_copts(),
data=data,
deps=deps,
linkopts=["-lpthread", "-lm"],
linkopts=["-lpthread", "-lm"] + linkopts,
linkstatic=linkstatic,
tags=tags,)
@ -254,13 +254,15 @@ def tf_cc_test_gpu(name, deps, linkstatic=0, tags=[], data=[], size="medium",
tf_cc_test(name, deps, linkstatic=linkstatic, tags=tags, data=data,
size=size, suffix=suffix, args=args)
def tf_cuda_cc_test(name, deps, tags=[], data=[], size="medium",linkstatic=0,args=[]):
def tf_cuda_cc_test(name, deps, tags=[], data=[], size="medium", linkstatic=0,
args=[], linkopts=[]):
tf_cc_test(name=name,
deps=deps,
tags=tags + ["manual"],
data=data,
size=size,
linkstatic=linkstatic,
linkopts=linkopts,
args=args)
tf_cc_test(name=name,
suffix="_gpu",
@ -269,21 +271,26 @@ def tf_cuda_cc_test(name, deps, tags=[], data=[], size="medium",linkstatic=0,arg
tags=tags + tf_cuda_tests_tags(),
data=data,
size=size,
linkopts=linkopts,
args=args)
# Create a cc_test for each of the tensorflow tests listed in "tests"
def tf_cc_tests(tests, deps, linkstatic=0, tags=[], size="medium", args=None):
def tf_cc_tests(tests, deps, linkstatic=0, tags=[], size="medium", args=None,
linkopts=[]):
for t in tests:
tf_cc_test(t, deps, linkstatic, tags=tags, size=size, args=args)
tf_cc_test(t, deps, linkstatic, tags=tags, size=size, args=args,
linkopts=linkopts)
def tf_cc_tests_gpu(tests, deps, linkstatic=0, tags=[], size="medium", args=None):
tf_cc_tests(tests, deps, linkstatic, tags=tags, size=size, args=args)
def tf_cuda_cc_tests(tests, deps, tags=[], size="medium", linkstatic=0, args=None):
def tf_cuda_cc_tests(tests, deps, tags=[], size="medium", linkstatic=0,
args=None, linkopts=[]):
for t in tests:
tf_cuda_cc_test(t, deps, tags=tags, size=size, linkstatic=linkstatic, args=args)
tf_cuda_cc_test(t, deps, tags=tags, size=size, linkstatic=linkstatic,
args=args, linkopts=linkopts)
def _cuda_copts():
"""Gets the appropriate set of copts for (maybe) CUDA compilation.

View File

@ -16,7 +16,12 @@
set -eux
TFDIR=$TEST_SRCDIR/tensorflow
if [ -d $TEST_SRCDIR/org_tensorflow ]; then
TFDIR=$TEST_SRCDIR/org_tensorflow/tensorflow
else
# Support 0.2.1- runfiles.
TFDIR=$TEST_SRCDIR/tensorflow
fi
DOXYGEN=doxygen
DOXYGEN_CONFIG="tf-doxy_for_md-config"
TMP_DIR=/tmp/tensorflow-docs

View File

@ -32,17 +32,38 @@ function main() {
echo "Could not find bazel-bin. Did you run from the root of the build tree?"
exit 1
fi
cp -R \
bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/{tensorflow,external} \
${TMPDIR}
if [ ! -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow ]; then
# Really old (0.2.1-) runfiles, without workspace name.
cp -R \
bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/{tensorflow,external} \
"${TMPDIR}"
RUNFILES=bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles
else
if [ -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/external ]; then
# Old-style runfiles structure (--legacy_external_runfiles).
cp -R \
bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/{tensorflow,external} \
"${TMPDIR}"
else
# New-style runfiles structure (--nolegacy_external_runfiles).
cp -R \
bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/tensorflow \
"${TMPDIR}"
mkdir "${TMPDIR}/external"
# Note: this makes an extra copy of org_tensorflow.
cp -R \
bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles \
"${TMPDIR}/external"
fi
RUNFILES=bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow
fi
# protobuf pip package doesn't ship with header files. Copy the headers
# over so user defined ops can be compiled.
rsync --include "*/" --include "*.h" --exclude "*" --prune-empty-dirs -a \
bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/google \
${TMPDIR}
rsync -a \
bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/third_party/eigen3 \
${TMPDIR}/third_party
$RUNFILES/google ${TMPDIR}
rsync -a $RUNFILES/third_party/eigen3 ${TMPDIR}/third_party
cp tensorflow/tools/pip_package/MANIFEST.in ${TMPDIR}
cp tensorflow/tools/pip_package/README ${TMPDIR}

View File

@ -16,11 +16,16 @@
set -e -o errexit
# Prefix expected paths with ./ locally and external/reponame/ for remote repos.
# TODO(kchodorow): remove once runfiles paths are fixed, see
# https://github.com/bazelbuild/bazel/issues/848.
script_path=$(dirname $(dirname $(dirname "$0")))
script_path=${script_path:-.}
if [ -d "../org_tensorflow" ]; then
script_path="../org_tensorflow"
else
# Prefix expected paths with ./ locally and external/reponame/ for remote repos.
# TODO(kchodorow): remove once runfiles paths are fixed, see
# https://github.com/bazelbuild/bazel/issues/848.
script_path=$(dirname $(dirname $(dirname "$0")))
script_path=${script_path:-.}
fi
EXPECTED_PATHS="$script_path/util/python/python_include"\
" $script_path/util/python/python_lib"\
" $script_path/third_party/py/numpy/numpy_include"