diff --git a/WORKSPACE b/WORKSPACE
index 1156a45a39e..ffebcde5541 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -1,3 +1,5 @@
+workspace(name = "org_tensorflow")
+
 # Uncomment and update the paths in these entries to build the Android demo.
 #android_sdk_repository(
 #    name = "androidsdk",
diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD
index 80495c9b8a1..7683cda7979 100644
--- a/tensorflow/contrib/util/BUILD
+++ b/tensorflow/contrib/util/BUILD
@@ -36,6 +36,10 @@ cc_binary(
 cc_test(
     name = "convert_graphdef_memmapped_format_test",
     srcs = ["convert_graphdef_memmapped_format_test.cc"],
+    linkopts = select({
+        "//tensorflow:darwin": ["-headerpad_max_install_names"],
+        "//conditions:default": [],
+    }),
     deps = [
         ":convert_graphdef_memmapped_format_lib",
         "//tensorflow/cc:cc_ops",
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index aa226886a80..81e2d085366 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1231,6 +1231,10 @@ cc_test(
 # higher level tests
 tf_cc_tests(
     size = "small",
+    linkopts = select({
+        "//tensorflow:darwin": ["-headerpad_max_install_names"],
+        "//conditions:default": [],
+    }),
     linkstatic = tf_kernel_tests_linkstatic(),
     tests = glob(
         [
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index f9201eeacf5..8b7f08bc8ae 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -324,6 +324,10 @@ cc_library(
 
 tf_cuda_cc_tests(
     size = "small",
+    linkopts = select({
+        "//tensorflow:darwin": ["-headerpad_max_install_names"],
+        "//conditions:default": [],
+    }),
     linkstatic = tf_kernel_tests_linkstatic(),
     tags = tf_cuda_tests_tags() + [],
     tests = [
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index cb3b0a536df..06b74869970 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -473,6 +473,10 @@ tf_cc_test(
 tf_cc_test(
     name = "slice_op_test",
     size = "small",
+    linkopts = select({
+        "//tensorflow:darwin": ["-headerpad_max_install_names"],
+        "//conditions:default": [],
+    }),
     deps = [
         ":ops_testutil",
         ":ops_util",
@@ -768,6 +772,10 @@ tf_cc_tests(
 )
 
 tf_cc_tests(
+    linkopts = select({
+        "//tensorflow:darwin": ["-headerpad_max_install_names"],
+        "//conditions:default": [],
+    }),
     tests = [
         "adjust_contrast_op_test",
         "colorspace_op_test",
@@ -1058,6 +1066,10 @@ tf_cuda_cc_test(
 tf_cuda_cc_test(
     name = "reduction_ops_test",
     size = "small",
+    linkopts = select({
+        "//tensorflow:darwin": ["-headerpad_max_install_names"],
+        "//conditions:default": [],
+    }),
     deps = [
         ":ops_testutil",
         ":ops_util",
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index f6420aaad8c..ce57f61d436 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -210,7 +210,7 @@ REGISTER_CPU(complex64);
 REGISTER_CPU(complex128);
 #if GOOGLE_CUDA
 REGISTER_GPU(float);
-// REGISTER_GPU(double);
+REGISTER_GPU(double);
 #if CUDA_VERSION >= 7050
 REGISTER_GPU(Eigen::half);
 #endif
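The matmul_op.cc hunk above re-enables the previously commented-out double-precision GPU kernel. A minimal sanity check, assuming a CUDA build of TensorFlow from this branch (0.x-era `tf.Session` API; the explicit device pin is only there to prove the GPU kernel exists):

```python
import numpy as np
import tensorflow as tf

x = np.arange(1., 5.).reshape([4, 1]).astype(np.float64)
y = np.arange(1., 3.).reshape([1, 2]).astype(np.float64)

# Before this change the float64 MatMul kernel was disabled, so this
# placement would fail with "no kernel registered" on the GPU.
with tf.device("/gpu:0"):
    z = tf.matmul(tf.constant(x), tf.constant(y))

with tf.Session() as sess:
    print(sess.run(z))  # same values as the CPU kernel, computed in float64
```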
diff --git a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
index 3febba04411..8e6cbc42700 100644
--- a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
+++ b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
@@ -130,6 +130,10 @@ Status DoTranspose(const Device& d, const Tensor& in,
       internal::Transpose<Device, complex64>(d, in, perm, out);
       break;
 
+    case DT_COMPLEX128:
+      internal::Transpose<Device, complex128>(d, in, perm, out);
+      break;
+
     default:
       return errors::Unimplemented("Unsupported dtype on GPU: ", in.dtype());
   }
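With DT_COMPLEX128 wired into `DoTranspose`, `tf.transpose` on a double-precision complex tensor can now run on the GPU instead of hitting the `Unimplemented` branch. A small sketch of the user-visible effect (mirroring the transpose_op_test.py cases later in this patch):

```python
import numpy as np
import tensorflow as tf

x = (1 + 2j) * np.arange(0, 21).reshape([3, 7]).astype(np.complex128)

with tf.Session() as sess:
    y = sess.run(tf.transpose(tf.constant(x)))

print(np.array_equal(y, x.T))  # True; previously complex128 only worked on CPU
```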
diff --git a/tensorflow/core/ops/compat/ops_history.v0.pbtxt b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
index 0a15d861c2c..f44b7ea05e7 100644
--- a/tensorflow/core/ops/compat/ops_history.v0.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
@@ -6366,9 +6366,7 @@ op {
       list {
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
         type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -6391,9 +6389,7 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
         type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -6416,8 +6412,6 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
         type: DT_COMPLEX64
         type: DT_COMPLEX128
       }
@@ -6993,9 +6987,6 @@ op {
       list {
        type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -7018,9 +7009,6 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -7043,10 +7031,6 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
       }
     }
   }
@@ -7762,9 +7746,6 @@ op {
       list {
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -7787,9 +7768,6 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -7812,10 +7790,6 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
       }
     }
   }
@@ -7837,9 +7811,6 @@ op {
       list {
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -7862,9 +7833,6 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -7887,10 +7855,6 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
       }
     }
   }
@@ -7927,9 +7891,7 @@ op {
       list {
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
         type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -7952,9 +7914,7 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
         type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -7977,8 +7937,6 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
         type: DT_COMPLEX64
         type: DT_COMPLEX128
       }
@@ -10263,9 +10221,6 @@ op {
       list {
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -10288,9 +10243,6 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -10313,10 +10265,6 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
-        type: DT_COMPLEX64
-        type: DT_COMPLEX128
       }
     }
   }
@@ -10390,9 +10338,7 @@ op {
       list {
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
         type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -10415,9 +10361,7 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
         type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -10440,8 +10384,6 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
         type: DT_COMPLEX64
         type: DT_COMPLEX128
       }
@@ -16137,9 +16079,7 @@ op {
       list {
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
         type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -16162,9 +16102,7 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
         type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -16187,8 +16125,6 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
         type: DT_COMPLEX64
         type: DT_COMPLEX128
       }
@@ -17916,9 +17852,7 @@ op {
       list {
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
         type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -17941,9 +17875,7 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
         type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -17966,8 +17898,6 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
         type: DT_COMPLEX64
         type: DT_COMPLEX128
       }
@@ -18089,9 +18019,7 @@ op {
       list {
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
         type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -18114,9 +18042,7 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
         type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -18139,8 +18065,6 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
         type: DT_COMPLEX64
         type: DT_COMPLEX128
       }
@@ -20897,9 +20821,7 @@ op {
       list {
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
         type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -20922,9 +20844,7 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
         type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -20947,8 +20867,6 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
         type: DT_COMPLEX64
         type: DT_COMPLEX128
       }
@@ -21696,9 +21614,7 @@ op {
       list {
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
         type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -21721,9 +21637,7 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
         type: DT_COMPLEX64
-        type: DT_INT64
       }
     }
   }
@@ -21746,8 +21660,6 @@ op {
         type: DT_HALF
         type: DT_FLOAT
         type: DT_DOUBLE
-        type: DT_INT32
-        type: DT_INT64
         type: DT_COMPLEX64
         type: DT_COMPLEX128
       }
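These compat entries record that several unary math ops no longer advertise integer (and, for the real-only ops, complex) types, matching the math_ops.cc change that follows. After the change, building such an op with a now-unsupported dtype fails at graph-construction time. A hedged sketch (the exact exception text comes from the 0.x op library and may differ):

```python
import tensorflow as tf

x = tf.constant([1, 4, 9])            # int32
y = tf.sqrt(tf.cast(x, tf.float64))   # fine: float64 remains in Sqrt's type list

try:
    tf.sqrt(x)                        # int32 was removed from Sqrt's allowed types
except TypeError as e:                # assumption: op_def_library raises TypeError here
    print("rejected as expected:", e)
```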
)doc");

 REGISTER_OP("Erfc")
-    .UNARY()
+    .UNARY_REAL()
     .Doc(R"doc(
 Computes the complementary error function of `x` element-wise.
 )doc");
 
 REGISTER_OP("Sigmoid")
-    .UNARY()
+    .UNARY_COMPLEX()
     .Doc(R"doc(
 Computes sigmoid of `x` element-wise.
 
@@ -224,18 +232,20 @@ Specifically, `y = 1 / (1 + exp(-x))`.
 )doc");
 
 REGISTER_OP("Sin")
-    .UNARY()
+    .UNARY_COMPLEX()
     .Doc(R"doc(
 Computes sin of x element-wise.
 )doc");
 
 REGISTER_OP("Cos")
-    .UNARY()
+    .UNARY_COMPLEX()
     .Doc(R"doc(
 Computes cos of x element-wise.
 )doc");
 
 #undef UNARY
+#undef UNARY_REAL
+#undef UNARY_COMPLEX
 
 REGISTER_OP("IsNan")
     .Input("x: T")
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 9ed8dbf13ec..8f5f9c635cd 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -521,13 +521,13 @@ def get_random_distorted_bottlenecks(
   ground_truths = []
   for unused_i in range(how_many):
     label_index = random.randrange(class_count)
-    label_name = image_lists.keys()[label_index]
+    label_name = list(image_lists.keys())[label_index]
     image_index = random.randrange(65536)
     image_path = get_image_path(image_lists, label_name, image_index, image_dir,
                                 category)
     if not gfile.Exists(image_path):
       tf.logging.fatal('File does not exist %s', image_path)
-    jpeg_data = gfile.FastGFile(image_path, 'r').read()
+    jpeg_data = gfile.FastGFile(image_path, 'rb').read()
     # Note that we materialize the distorted_image_data as a numpy array before
     # running inference on the image. This involves 2 memory copies and
     # might be optimized in other implementations.
@@ -616,7 +616,7 @@ def add_input_distortions(flip_left_right, random_crop, random_scale,
   """
   jpeg_data = tf.placeholder(tf.string, name='DistortJPGInput')
-  decoded_image = tf.image.decode_jpeg(jpeg_data)
+  decoded_image = tf.image.decode_jpeg(jpeg_data, channels=MODEL_INPUT_DEPTH)
   decoded_image_as_float = tf.cast(decoded_image, dtype=tf.float32)
   decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
   margin_scale = 1.0 + (random_crop / 100.0)
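Two of the retrain.py fixes are generic Python 3 portability patterns rather than TensorFlow-specific changes: `dict.keys()` returns a view (not an indexable list) on Python 3, and image files must be read as bytes. A standalone illustration in plain Python (the dict contents and file path are hypothetical):

```python
image_lists = {'roses': [...], 'tulips': [...]}  # hypothetical label -> images map

# Python 3: keys() is a view object; wrap it in list() before indexing.
label_name = list(image_lists.keys())[0]

# Binary data must be opened in 'rb' mode; mode 'r' would attempt a text
# decode of the JPEG bytes and fail (or corrupt the data).
with open('/tmp/example.jpg', 'rb') as f:  # hypothetical path
    jpeg_data = f.read()
```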
diff --git a/tensorflow/g3doc/api_docs/python/constant_op.md b/tensorflow/g3doc/api_docs/python/constant_op.md
index 008174f9d6f..1aaf39bd50b 100644
--- a/tensorflow/g3doc/api_docs/python/constant_op.md
+++ b/tensorflow/g3doc/api_docs/python/constant_op.md
@@ -60,7 +60,7 @@ tf.zeros_like(tensor) ==> [[0, 0, 0], [0, 0, 0]]
 
 *  `tensor`: A `Tensor`.
 *  `dtype`: A type for the returned `Tensor`. Must be `float32`, `float64`,
-    `int8`, `int16`, `int32`, `int64`, `uint8`, or `complex64`.
+    `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, or `complex128`.
 *  `name`: A name for the operation (optional).
 
@@ -119,7 +119,7 @@ tf.ones_like(tensor) ==> [[1, 1, 1], [1, 1, 1]]
 
 *  `tensor`: A `Tensor`.
 *  `dtype`: A type for the returned `Tensor`. Must be `float32`, `float64`,
-    `int8`, `int16`, `int32`, `int64`, `uint8`, or `complex64`.
+    `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, or `complex128`.
 *  `name`: A name for the operation (optional).
 
diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md
index e92d5b94108..9405ed74d1e 100644
--- a/tensorflow/g3doc/api_docs/python/nn.md
+++ b/tensorflow/g3doc/api_docs/python/nn.md
@@ -163,7 +163,7 @@ case where both types are quantized.
 
 *  `value`: A `Tensor` with type `float`, `double`, `int64`, `int32`, `uint8`,
-   `int16`, `int8`, or `complex64`.
+   `int16`, `int8`, `complex64`, or `complex128`.
 *  `bias`: A 1-D `Tensor` with size matching the last dimension of `value`.
    Must be the same type as `value` unless `value` is a quantized type,
    in which case a different quantized type may be used.
@@ -186,7 +186,7 @@ Specifically, `y = 1 / (1 + exp(-x))`.
 
 ##### Args:
 
-*  `x`: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
+*  `x`: A Tensor with type `float`, `double`, `int32`, `complex64`, `complex128`, `int64`,
    or `qint32`.
 *  `name`: A name for the operation (optional).
 
@@ -205,7 +205,7 @@ Computes hyperbolic tangent of `x` element-wise.
 
 ##### Args:
 
-*  `x`: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
+*  `x`: A Tensor with type `float`, `double`, `int32`, `complex64`, `complex128`, `int64`,
    or `qint32`.
 *  `name`: A name for the operation (optional).
 
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index bdb94cda102..0d7b851bac9 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -658,7 +658,7 @@ bazel build -c opt --config=cuda //tensorflow/tools/pip_package:build_pip_packag
 
 mkdir _python_build
 cd _python_build
-ln -s ../bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/* .
+ln -s ../bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow* .
 ln -s ../tensorflow/tools/pip_package/* .
 python setup.py develop
 ```
diff --git a/tensorflow/g3doc/tutorials/pdes/index.md b/tensorflow/g3doc/tutorials/pdes/index.md
index ca240347598..9bc4340285e 100755
--- a/tensorflow/g3doc/tutorials/pdes/index.md
+++ b/tensorflow/g3doc/tutorials/pdes/index.md
@@ -19,7 +19,7 @@ import numpy as np
 
 #Imports for visualization
 import PIL.Image
-from cStringIO import StringIO
+from io import BytesIO
 from IPython.display import clear_output, Image, display
 ```
 
@@ -30,8 +30,9 @@ def DisplayArray(a, fmt='jpeg', rng=[0,1]):
   """Display an array as a picture."""
   a = (a - rng[0])/float(rng[1] - rng[0])*255
   a = np.uint8(np.clip(a, 0, 255))
-  f = StringIO()
+  f = BytesIO()
   PIL.Image.fromarray(a).save(f, fmt)
+  clear_output(wait=True)
   display(Image(data=f.getvalue()))
 ```
 
@@ -132,10 +133,7 @@ tf.initialize_all_variables().run()
 
 for i in range(1000):
   # Step simulation
   step.run({eps: 0.03, damping: 0.04})
-  # Visualize every 50 steps
-  if i % 50 == 0:
-    clear_output()
-    DisplayArray(U.eval(), rng=[-0.1, 0.1])
+  DisplayArray(U.eval(), rng=[-0.1, 0.1])
 ```
 
 ![jpeg](../../images/pde_output_2.jpg)
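The PDE-tutorial edits are Python 3 and notebook-behavior updates: PIL writes encoded image bytes, so the in-memory buffer must be `io.BytesIO` (the Python 2 `cStringIO` module no longer exists), and `clear_output(wait=True)` defers clearing until the next frame arrives, which turns the per-step display into a smooth animation instead of a flickering one. A condensed, standalone sketch of the pattern (assumes a Jupyter/IPython environment with PIL installed):

```python
import numpy as np
import PIL.Image
from io import BytesIO
from IPython.display import clear_output, Image, display

def show(a):
    """Render a [0, 1] float array as an inline JPEG frame."""
    buf = BytesIO()                            # bytes buffer: PIL writes bytes, not str
    a8 = np.uint8(np.clip(a * 255, 0, 255))
    PIL.Image.fromarray(a8).save(buf, 'jpeg')
    clear_output(wait=True)                    # swap frames without flicker
    display(Image(data=buf.getvalue()))
```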
diff --git a/tensorflow/models/rnn/translate/seq2seq_model.py b/tensorflow/models/rnn/translate/seq2seq_model.py
index b0d8ff43db3..a921f28c06f 100644
--- a/tensorflow/models/rnn/translate/seq2seq_model.py
+++ b/tensorflow/models/rnn/translate/seq2seq_model.py
@@ -83,17 +83,15 @@ class Seq2SeqModel(object):
     softmax_loss_function = None
     # Sampled softmax only makes sense if we sample less than vocabulary size.
     if num_samples > 0 and num_samples < self.target_vocab_size:
-      with tf.device("/cpu:0"):
-        w = tf.get_variable("proj_w", [size, self.target_vocab_size])
-        w_t = tf.transpose(w)
-        b = tf.get_variable("proj_b", [self.target_vocab_size])
+      w = tf.get_variable("proj_w", [size, self.target_vocab_size])
+      w_t = tf.transpose(w)
+      b = tf.get_variable("proj_b", [self.target_vocab_size])
       output_projection = (w, b)
 
       def sampled_loss(inputs, labels):
-        with tf.device("/cpu:0"):
-          labels = tf.reshape(labels, [-1, 1])
-          return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples,
-                                            self.target_vocab_size)
+        labels = tf.reshape(labels, [-1, 1])
+        return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples,
+                                          self.target_vocab_size)
       softmax_loss_function = sampled_loss
 
     # Create the internal multi-layer cell for our RNN.
diff --git a/tensorflow/python/framework/random_seed.py b/tensorflow/python/framework/random_seed.py
index b70f626a9ee..9f503b8f29c 100644
--- a/tensorflow/python/framework/random_seed.py
+++ b/tensorflow/python/framework/random_seed.py
@@ -24,6 +24,10 @@ from tensorflow.python.framework import ops
 
 _DEFAULT_GRAPH_SEED = 87654321
+_MAXINT32 = 2**31 - 1
+
+
+def _truncate_seed(seed):
+  return seed % _MAXINT32  # truncate to fit into 32-bit integer
 
 
 def get_seed(op_seed):
@@ -47,12 +51,12 @@ def get_seed(op_seed):
   graph_seed = ops.get_default_graph().seed
   if graph_seed is not None:
     if op_seed is not None:
-      return graph_seed, op_seed
+      return _truncate_seed(graph_seed), _truncate_seed(op_seed)
     else:
-      return graph_seed, ops.get_default_graph()._last_id
+      return _truncate_seed(graph_seed), _truncate_seed(ops.get_default_graph()._last_id)
   else:
     if op_seed is not None:
-      return _DEFAULT_GRAPH_SEED, op_seed
+      return _truncate_seed(_DEFAULT_GRAPH_SEED), _truncate_seed(op_seed)
     else:
       return None, None
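With `_truncate_seed` in place, arbitrarily large (or negative) Python integers are folded into the 32-bit range the random-op kernels expect, instead of overflowing when the attr is serialized. The behavior in isolation, using the helper exactly as defined above:

```python
_MAXINT32 = 2**31 - 1

def _truncate_seed(seed):
    return seed % _MAXINT32  # truncate to fit into 32-bit integer

print(_truncate_seed(345))      # 345: small seeds pass through unchanged
print(_truncate_seed(2**100))   # folded into [0, 2**31 - 1)
print(_truncate_seed(-2**100))  # also non-negative: Python's % follows the modulus sign
```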
diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py
index 25553097a6c..6c817a5da80 100644
--- a/tensorflow/python/kernel_tests/matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/matmul_op_test.py
@@ -95,6 +95,7 @@ class MatMulTest(tf.test.TestCase):
     x = np.arange(1., 5.).reshape([4, 1]).astype(np.float64)
     y = np.arange(1., 3.).reshape([1, 2]).astype(np.float64)
     self._testCpuMatmul(x, y)
+    self._testGpuMatmul(x, y)
 
   def testHalfBasic(self):
     x = np.arange(1., 5.).reshape([4, 1]).astype(np.float16)
@@ -135,6 +136,7 @@ class MatMulTest(tf.test.TestCase):
       x = self._randMatrix(n, k, np.float64)
       y = self._randMatrix(k, m, np.float64)
       self._testCpuMatmul(x, y)
+      self._testGpuMatmul(x, y)
 
   def testHalfRandom(self):
     for _ in range(10):
@@ -185,6 +187,7 @@ class MatMulTest(tf.test.TestCase):
       x = self._randMatrix(k, n, np.float64)
       y = self._randMatrix(m, k, np.float64)
       self._testCpuMatmul(x, y, True, True)
+      self._testGpuMatmul(x, y, True, True)
 
   def testHalfRandomTransposeBoth(self):
     for _ in range(10):
diff --git a/tensorflow/python/kernel_tests/random_ops_test.py b/tensorflow/python/kernel_tests/random_ops_test.py
index 45b61be0c31..f4ed26b1e25 100644
--- a/tensorflow/python/kernel_tests/random_ops_test.py
+++ b/tensorflow/python/kernel_tests/random_ops_test.py
@@ -237,9 +237,10 @@ class RandomUniformTest(tf.test.TestCase):
   def testSeed(self):
     for use_gpu in False, True:
       for dt in tf.float16, tf.float32, tf.float64, tf.int32, tf.int64:
-        sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=345)
-        sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=345)
-        self.assertAllEqual(sx(), sy())
+        for seed in [345, 2**100, -2**100]:
+          sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=seed)
+          sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=seed)
+          self.assertAllEqual(sx(), sy())
 
   def testNoCSE(self):
     shape = [2, 3, 4]
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index 7881fe152d8..e4ce40303ee 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -131,7 +131,7 @@ class TransposeTest(tf.test.TestCase):
     self._compare_cpu_gpu(
         np.arange(0, 16).reshape([1, 2, 1, 2, 1, 2, 1, 2]).astype(np.float64))
 
-  def testSComplex(self):
+  def testComplex64(self):
     self._testBoth(np.complex(1, 2) *
                    np.arange(0, 21).reshape([3, 7]).astype(np.complex64))
     self._testBoth(np.complex(1, 2) *
@@ -140,6 +140,15 @@ class TransposeTest(tf.test.TestCase):
         np.complex(1, 2) *
         np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(np.complex64))
 
+  def testComplex128(self):
+    self._testBoth(np.complex(1, 2) *
+                   np.arange(0, 21).reshape([3, 7]).astype(np.complex128))
+    self._testBoth(np.complex(1, 2) *
+                   np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.complex128))
+    self._testBoth(
+        np.complex(1, 2) *
+        np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(np.complex128))
+
   def testInt8(self):
     self._testBoth(np.arange(0, 21).reshape([3, 7]).astype(np.int8))
     self._testBoth(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int8))
diff --git a/tensorflow/python/ops/seq2seq.py b/tensorflow/python/ops/seq2seq.py
index e920c95dec1..cb1773c7f16 100644
--- a/tensorflow/python/ops/seq2seq.py
+++ b/tensorflow/python/ops/seq2seq.py
@@ -260,9 +260,8 @@ def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols,
     proj_biases.get_shape().assert_is_compatible_with([num_symbols])
 
   with variable_scope.variable_scope(scope or "embedding_rnn_decoder"):
-    with ops.device("/cpu:0"):
-      embedding = variable_scope.get_variable("embedding",
-                                              [num_symbols, embedding_size])
+    embedding = variable_scope.get_variable("embedding",
+                                            [num_symbols, embedding_size])
     loop_function = _extract_argmax_and_embed(
         embedding, output_projection,
         update_embedding_for_previous) if feed_previous else None
@@ -398,9 +397,8 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
     proj_biases.get_shape().assert_is_compatible_with([num_symbols])
 
   with variable_scope.variable_scope(scope or "embedding_tied_rnn_seq2seq"):
-    with ops.device("/cpu:0"):
-      embedding = variable_scope.get_variable("embedding",
-                                              [num_symbols, embedding_size])
+    embedding = variable_scope.get_variable("embedding",
+                                            [num_symbols, embedding_size])
 
     emb_encoder_inputs = [embedding_ops.embedding_lookup(embedding, x)
                           for x in encoder_inputs]
@@ -636,9 +634,8 @@ def embedding_attention_decoder(decoder_inputs, initial_state, attention_states,
     proj_biases.get_shape().assert_is_compatible_with([num_symbols])
 
   with variable_scope.variable_scope(scope or "embedding_attention_decoder"):
-    with ops.device("/cpu:0"):
-      embedding = variable_scope.get_variable("embedding",
-                                              [num_symbols, embedding_size])
+    embedding = variable_scope.get_variable("embedding",
+                                            [num_symbols, embedding_size])
     loop_function = _extract_argmax_and_embed(
         embedding, output_projection,
         update_embedding_for_previous) if feed_previous else None
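A pattern running through both seq2seq_model.py and seq2seq.py: embedding and projection variables are no longer hard-pinned to `/cpu:0`, so their placement now follows whatever device context the caller establishes. A hedged sketch of the difference (variable names and shapes are illustrative, not from the patch):

```python
import tensorflow as tf

# Before: the variable was forced onto the CPU regardless of the caller's wishes.
with tf.device("/cpu:0"):
    embedding_cpu = tf.get_variable("embedding_cpu", [1000, 128])

# After: no explicit pin inside the library function, so the surrounding
# scope (or the automatic placer) decides, and callers who have GPU kernels
# for the lookup can keep the embedding on the GPU.
with tf.device("/gpu:0"):
    embedding = tf.get_variable("embedding", [1000, 128])
```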
diff --git a/tensorflow/python/training/adagrad_test.py b/tensorflow/python/training/adagrad_test.py
index c5e118bdc71..4125a7aa3c4 100644
--- a/tensorflow/python/training/adagrad_test.py
+++ b/tensorflow/python/training/adagrad_test.py
@@ -25,7 +25,7 @@ import tensorflow as tf
 class AdagradOptimizerTest(tf.test.TestCase):
 
   def doTestBasic(self, use_locking=False):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         var0 = tf.Variable([1.0, 2.0], dtype=dtype)
         var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -56,7 +56,7 @@ class AdagradOptimizerTest(tf.test.TestCase):
     self.doTestBasic(use_locking=True)
 
   def testTensorLearningRate(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         var0 = tf.Variable([1.0, 2.0], dtype=dtype)
         var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -80,43 +80,8 @@ class AdagradOptimizerTest(tf.test.TestCase):
         self.assertAllCloseAccordingToType(
             np.array([2.715679168701172, 3.715679168701172]), var1.eval())
 
-  def testFloat64(self):
-    with self.test_session():
-      opt = tf.train.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
-
-      # compute_gradients.
-      values = [1.0, 3.0]
-      good_vars = [tf.Variable([v]) for v in values]
-      bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
-      self.assertRaisesRegexp(
-          ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
-          opt.compute_gradients, bad_loss, good_vars)
-      bad_vars = [
-          tf.Variable(np.array([v], np.float64), name="bad_var")
-          for v in values
-      ]
-      self.assertRaisesRegexp(
-          ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
-          opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
-          bad_vars)
-      opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
-
-      # apply_gradients.
-      bad_grads = [
-          tf.constant([0.1], dtype=np.float64, name="bad_grad"),
-          tf.constant([0.01])
-      ]
-      self.assertRaisesRegexp(
-          ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
-          opt.apply_gradients, zip(bad_grads, good_vars))
-      good_grads = [tf.constant([0.01]), tf.constant([0.02])]
-      self.assertRaisesRegexp(
-          ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
-          opt.apply_gradients, zip(good_grads, bad_vars))
-      opt.apply_gradients(zip(good_grads, good_vars))
-
   def testSparseBasic(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         var0 = tf.Variable([[1.0], [2.0]], dtype=dtype)
         var1 = tf.Variable([[3.0], [4.0]], dtype=dtype)
@@ -145,7 +110,7 @@ class AdagradOptimizerTest(tf.test.TestCase):
             np.array([[3.0], [3.715679168701172]]), var1.eval())
 
   def testSparseStability(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         shape = [1, 6]
         var0 = tf.Variable(
@@ -175,7 +140,7 @@ class AdagradOptimizerTest(tf.test.TestCase):
                 0.0144573, -0.01029443]]), var0.eval())
 
   def testSharing(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         var0 = tf.Variable([1.0, 2.0], dtype=dtype)
         var1 = tf.Variable([3.0, 4.0], dtype=dtype)
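The deleted `testFloat64` cases asserted that optimizers *reject* float64; with float64 added to `Optimizer._valid_dtypes` (see the optimizer.py hunk further below), the same graphs now build and train, so the parameterized dtype loops cover float64 instead. A minimal end-to-end sketch on this branch's API:

```python
import tensorflow as tf

var0 = tf.Variable([1.0, 2.0], dtype=tf.float64)
loss = tf.reduce_sum(tf.square(var0))        # float64 loss: no cast to float32 needed
opt = tf.train.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
train_op = opt.minimize(loss)                # previously raised "Invalid type ... float64"

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    sess.run(train_op)
    print(sess.run(var0))                    # variables updated in double precision
```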
diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py
index ce6ec99f846..df97b1be508 100644
--- a/tensorflow/python/training/adam_test.py
+++ b/tensorflow/python/training/adam_test.py
@@ -36,7 +36,7 @@ def adam_update_numpy(param, g_t, t, m, v, alpha=0.001, beta1=0.9, beta2=0.999,
 class AdamOptimizerTest(tf.test.TestCase):
 
   def testSparse(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
@@ -79,7 +79,7 @@ class AdamOptimizerTest(tf.test.TestCase):
         self.assertAllCloseAccordingToType(var1_np, var1.eval())
 
   def testBasic(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
@@ -116,7 +116,7 @@ class AdamOptimizerTest(tf.test.TestCase):
         self.assertAllCloseAccordingToType(var1_np, var1.eval())
 
   def testTensorLearningRate(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
@@ -152,41 +152,8 @@ class AdamOptimizerTest(tf.test.TestCase):
         self.assertAllCloseAccordingToType(var0_np, var0.eval())
         self.assertAllCloseAccordingToType(var1_np, var1.eval())
 
-  def testFloat64(self):
-    with self.test_session():
-      opt = tf.train.AdamOptimizer()
-
-      # compute_gradients.
-      values = [1.0, 3.0]
-      good_vars = [tf.Variable([v]) for v in values]
-      bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
-      self.assertRaisesRegexp(
-          ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
-          opt.compute_gradients, bad_loss, good_vars)
-      bad_vars = [
-          tf.Variable(np.array([v], np.float64), name="bad_var")
-          for v in values]
-      self.assertRaisesRegexp(
-          ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
-          opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
-          bad_vars)
-      opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
-
-      # apply_gradients.
-      bad_grads = [
-          tf.constant([0.1], dtype=np.float64, name="bad_grad"),
-          tf.constant([0.01])]
-      self.assertRaisesRegexp(
-          ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
-          opt.apply_gradients, zip(bad_grads, good_vars))
-      good_grads = [tf.constant([0.01]), tf.constant([0.02])]
-      self.assertRaisesRegexp(
-          ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
-          opt.apply_gradients, zip(good_grads, bad_vars))
-      opt.apply_gradients(zip(good_grads, good_vars))
-
   def testSharing(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
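These tests compare TensorFlow's Adam against the file's `adam_update_numpy` reference, which now also runs in float64. For orientation, a self-contained version of the standard Adam update that matches the defaults visible in the test's signature (the body is a reconstruction of the reference helper, not a verbatim quote of it):

```python
import numpy as np

def adam_update_numpy(param, g_t, t, m, v,
                      alpha=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
    """One Adam step in plain numpy; works unchanged for float32 and float64."""
    alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t)  # bias-corrected rate
    m_t = beta1 * m + (1 - beta1) * g_t          # first-moment (mean) estimate
    v_t = beta2 * v + (1 - beta2) * g_t * g_t    # second-moment (uncentered var) estimate
    param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon)
    return param_t, m_t, v_t

p = np.array([1.0, 2.0], dtype=np.float64)
g = np.array([0.1, 0.1], dtype=np.float64)
p, m, v = adam_update_numpy(p, g, t=1, m=0.0, v=0.0)
print(p.dtype, p)  # float64 throughout
```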
diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py
index 51619630321..26e1ae7dee9 100644
--- a/tensorflow/python/training/gradient_descent_test.py
+++ b/tensorflow/python/training/gradient_descent_test.py
@@ -25,7 +25,7 @@ import tensorflow as tf
 class GradientDescentOptimizerTest(tf.test.TestCase):
 
   def testBasic(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         var0 = tf.Variable([1.0, 2.0], dtype=dtype)
         var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -46,7 +46,7 @@ class GradientDescentOptimizerTest(tf.test.TestCase):
             [3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], var1.eval())
 
   def testTensorLearningRate(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         var0 = tf.Variable([1.0, 2.0], dtype=dtype)
         var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -67,43 +67,8 @@ class GradientDescentOptimizerTest(tf.test.TestCase):
         self.assertAllCloseAccordingToType(
             [3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], var1.eval())
 
-  def testFloat64(self):
-    with self.test_session():
-      opt = tf.train.GradientDescentOptimizer(3.0)
-
-      # compute_gradients.
-      values = [1.0, 3.0]
-      good_vars = [tf.Variable([v]) for v in values]
-      bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
-      self.assertRaisesRegexp(
-          ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
-          opt.compute_gradients, bad_loss, good_vars)
-      bad_vars = [
-          tf.Variable(np.array([v], np.float64), name="bad_var")
-          for v in values
-      ]
-      self.assertRaisesRegexp(
-          ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
-          opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
-          bad_vars)
-      opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
-
-      # apply_gradients.
-      bad_grads = [
-          tf.constant([0.1], dtype=np.float64, name="bad_grad"),
-          tf.constant([0.01])
-      ]
-      self.assertRaisesRegexp(
-          ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
-          opt.apply_gradients, zip(bad_grads, good_vars))
-      good_grads = [tf.constant([0.01]), tf.constant([0.02])]
-      self.assertRaisesRegexp(
-          ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
-          opt.apply_gradients, zip(good_grads, bad_vars))
-      opt.apply_gradients(zip(good_grads, good_vars))
-
   def testGradWrtRef(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         opt = tf.train.GradientDescentOptimizer(3.0)
         values = [1.0, 3.0]
@@ -114,7 +79,7 @@ class GradientDescentOptimizerTest(tf.test.TestCase):
         self.assertAllCloseAccordingToType([1.0], grad.eval())
 
   def testWithGlobalStep(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         global_step = tf.Variable(0, trainable=False)
         var0 = tf.Variable([1.0, 2.0], dtype=dtype)
@@ -138,7 +103,7 @@ class GradientDescentOptimizerTest(tf.test.TestCase):
         self.assertAllCloseAccordingToType(1, global_step.eval())
 
   def testSparseBasic(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         var0 = tf.Variable([[1.0], [2.0]], dtype=dtype)
         var1 = tf.Variable([[3.0], [4.0]], dtype=dtype)
diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py
index ab48d347827..203399eab3c 100644
--- a/tensorflow/python/training/learning_rate_decay.py
+++ b/tensorflow/python/training/learning_rate_decay.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import control_flow_ops
 
 
 def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
@@ -84,3 +85,62 @@ def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
     if staircase:
       p = math_ops.floor(p)
     return math_ops.mul(learning_rate, math_ops.pow(decay_rate, p), name=name)
+
+
+def piecewise_constant(x, boundaries, values, name=None):
+  """Piecewise constant from boundaries and interval values.
+
+  Example: use a learning rate that's 1.0 for the first 100000 steps, 0.5
+  for steps 100001 to 110000, and 0.1 for any additional steps.
+
+  ```python
+  global_step = tf.Variable(0, trainable=False)
+  boundaries = [100000, 110000]
+  values = [1.0, 0.5, 0.1]
+  learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
+
+  # Later, whenever we perform an optimization step, we increment global_step.
+  ```
+
+  Args:
+    x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`,
+      `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
+    boundaries: A list of `Tensor`s, `int`s, or `float`s with strictly
+      increasing entries, and with all elements having the same type as `x`.
+    values: A list of `Tensor`s, `float`s, or `int`s that specifies the values
+      for the intervals defined by `boundaries`. It should have one more
+      element than `boundaries`, and all elements should have the same type.
+    name: A string. Optional name of the operation. Defaults to
+      'PiecewiseConstant'.
+
+  Returns:
+    A 0-D Tensor. Its value is `values[0]` when `x <= boundaries[0]`,
+    `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ...,
+    and `values[-1]` when `x > boundaries[-1]`.
+  """
+
+  with ops.op_scope([x, boundaries, values, name],
+                    name, 'PiecewiseConstant') as name:
+    x = ops.convert_to_tensor(x)
+    # Avoid explicit conversion to x's dtype. This could result in faulty
+    # comparisons, for example if floats are converted to integers.
+    boundaries = ops.convert_n_to_tensor(boundaries)
+    if not all(b.dtype == x.dtype for b in boundaries):
+      raise ValueError('boundaries must have the same dtype as x.')
+    # TODO(rdipietro): Ensure that boundaries' elements are strictly increasing.
+    values = ops.convert_n_to_tensor(values)
+    if not all(v.dtype == values[0].dtype for v in values):
+      raise ValueError('values must have elements all with the same dtype.')
+
+    pred_fn_pairs = {}
+    pred_fn_pairs[x <= boundaries[0]] = lambda: values[0]
+    pred_fn_pairs[x > boundaries[-1]] = lambda: values[-1]
+    for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
+      # Need to bind v here; we can do this with lambda v=v: ...
+      pred = (x > low) & (x <= high)
+      pred_fn_pairs[pred] = lambda v=v: v
+
+    # The default isn't needed here because our conditions are mutually
+    # exclusive and exhaustive, but tf.case requires it.
+    default = lambda: values[0]
+    return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
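The `lambda v=v:` comment in `piecewise_constant` points at a classic Python pitfall: closures capture variables, not values, so every lambda created in a loop would otherwise see the loop variable's final value. A plain-Python demonstration of why the default-argument binding is needed:

```python
# Without the trick, all closures share the same loop variable.
fns = [lambda: v for v in [1.0, 0.5, 0.1]]
print([f() for f in fns])        # [0.1, 0.1, 0.1] -- every lambda sees the last value

# Binding v as a default argument freezes the current value per iteration.
fns = [lambda v=v: v for v in [1.0, 0.5, 0.1]]
print([f() for f in fns])        # [1.0, 0.5, 0.1]
```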
diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py
index 21ea03826cc..6fabd58fe31 100644
--- a/tensorflow/python/training/learning_rate_decay_test.py
+++ b/tensorflow/python/training/learning_rate_decay_test.py
@@ -72,6 +72,41 @@ class LRDecayTest(test_util.TensorFlowTestCase):
       expected = .1 * 0.96**(100 // 3)
       self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
 
+  def testPiecewiseConstant(self):
+    with self.test_session():
+      x = variables.Variable(-999)
+      assign_100 = x.assign(100)
+      assign_105 = x.assign(105)
+      assign_110 = x.assign(110)
+      assign_120 = x.assign(120)
+      assign_999 = x.assign(999)
+      pc = learning_rate_decay.piecewise_constant(x, [100, 110, 120],
+                                                  [1.0, 0.1, 0.01, 0.001])
+
+      variables.initialize_all_variables().run()
+      self.assertAllClose(pc.eval(), 1.0, 1e-6)
+      assign_100.op.run()
+      self.assertAllClose(pc.eval(), 1.0, 1e-6)
+      assign_105.op.run()
+      self.assertAllClose(pc.eval(), 0.1, 1e-6)
+      assign_110.op.run()
+      self.assertAllClose(pc.eval(), 0.1, 1e-6)
+      assign_120.op.run()
+      self.assertAllClose(pc.eval(), 0.01, 1e-6)
+      assign_999.op.run()
+      self.assertAllClose(pc.eval(), 0.001, 1e-6)
+
+  def testPiecewiseConstantEdgeCases(self):
+    with self.test_session():
+      with self.assertRaises(ValueError):
+        x_int = variables.Variable(0, dtype=variables.dtypes.int32)
+        boundaries, values = [-1.0, 1.0], [1, 2, 3]
+        pc = learning_rate_decay.piecewise_constant(x_int, boundaries, values)
+      with self.assertRaises(ValueError):
+        x = variables.Variable(0.0)
+        boundaries, values = [-1.0, 1.0], [1.0, 2, 3]
+        pc = learning_rate_decay.piecewise_constant(x, boundaries, values)
+
 
 if __name__ == "__main__":
   googletest.main()
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
index 72e0eed4c4a..88468f56a82 100644
--- a/tensorflow/python/training/momentum_test.py
+++ b/tensorflow/python/training/momentum_test.py
@@ -26,7 +26,7 @@ import tensorflow as tf
 class MomentumOptimizerTest(tf.test.TestCase):
 
   def testBasic(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         var0 = tf.Variable([1.0, 2.0], dtype=dtype)
         var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -81,7 +81,7 @@ class MomentumOptimizerTest(tf.test.TestCase):
             var1.eval())
 
   def testTensorLearningRateAndMomentum(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         var0 = tf.Variable([1.0, 2.0], dtype=dtype)
         var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -136,39 +136,6 @@ class MomentumOptimizerTest(tf.test.TestCase):
                       3.98 - ((0.9 * 0.01 + 0.01) * 2.0)]),
             var1.eval())
 
-  def testFloat64(self):
-    with self.test_session():
-      opt = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9)
-
-      # compute_gradients.
-      values = [1.0, 3.0]
-      good_vars = [tf.Variable([v]) for v in values]
-      bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
-      self.assertRaisesRegexp(
-          ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
-          opt.compute_gradients, bad_loss, good_vars)
-      bad_vars = [
-          tf.Variable(np.array([v], np.float64), name="bad_var")
-          for v in values]
-      self.assertRaisesRegexp(
-          ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
-          opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
-          bad_vars)
-      opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
-
-      # apply_gradients.
-      bad_grads = [
-          tf.constant([0.1], dtype=np.float64, name="bad_grad"),
-          tf.constant([0.01])]
-      self.assertRaisesRegexp(
-          ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
-          opt.apply_gradients, zip(bad_grads, good_vars))
-      good_grads = [tf.constant([0.01]), tf.constant([0.02])]
-      self.assertRaisesRegexp(
-          ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
-          opt.apply_gradients, zip(good_grads, bad_vars))
-      opt.apply_gradients(zip(good_grads, good_vars))
-
   def _dbParamsMom01(self):
     """Return dist-belief momentum values.
 
@@ -222,7 +189,7 @@ class MomentumOptimizerTest(tf.test.TestCase):
       self.assertAllClose(np.array(db_out[i]), var0.eval())
 
   def testSparse(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         var0 = tf.Variable(tf.zeros([4, 2], dtype=dtype))
         var1 = tf.Variable(tf.constant(1.0, dtype, [4, 2]))
@@ -290,7 +257,7 @@ class MomentumOptimizerTest(tf.test.TestCase):
             var1.eval()[2])
 
   def testSharing(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         var0 = tf.Variable([1.0, 2.0], dtype=dtype)
         var1 = tf.Variable([3.0, 4.0], dtype=dtype)
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index a8b41be9239..623c76e18b0 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -376,7 +376,7 @@ class Optimizer(object):
     Returns:
       Valid types for loss, variables and gradients.
     """
-    return set([dtypes.float16, dtypes.float32])
+    return set([dtypes.float16, dtypes.float32, dtypes.float64])
 
   def _create_slots(self, var_list):
     """Create all slots needed by the variables.
diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py
index 54d400a51c0..f87a207c60e 100644
--- a/tensorflow/python/training/optimizer_test.py
+++ b/tensorflow/python/training/optimizer_test.py
@@ -23,7 +23,7 @@ import tensorflow as tf
 class OptimizerTest(tf.test.TestCase):
 
   def testBasic(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         var0 = tf.Variable([1.0, 2.0], dtype=dtype)
         var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -43,7 +43,7 @@ class OptimizerTest(tf.test.TestCase):
         self.assertAllClose([-6., -5.], var1.eval())
 
   def testAggregationMethod(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         var0 = tf.Variable([1.0, 2.0], dtype=dtype)
         var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -67,7 +67,7 @@ class OptimizerTest(tf.test.TestCase):
         self.assertAllClose([-6., -5.], var1.eval())
 
   def testPrecomputedGradient(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         var0 = tf.Variable([1.0, 2.0], dtype=dtype)
         var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -92,7 +92,7 @@ class OptimizerTest(tf.test.TestCase):
             [3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)], var1.eval())
 
   def testNoVariables(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
      with self.test_session():
         var0 = tf.Variable([1.0, 2.0], dtype=dtype, trainable=False)
         var1 = tf.Variable([3.0, 4.0], dtype=dtype, trainable=False)
@@ -102,7 +102,7 @@ class OptimizerTest(tf.test.TestCase):
           sgd_op.minimize(cost)
 
   def testNoGradients(self):
-    for dtype in [tf.half, tf.float32]:
+    for dtype in [tf.half, tf.float32, tf.float64]:
       with self.test_session():
         var0 = tf.Variable([1.0, 2.0], dtype=dtype)
         var1 = tf.Variable([3.0, 4.0], dtype=dtype)
diff --git a/tensorflow/stream_executor/dso_loader.cc b/tensorflow/stream_executor/dso_loader.cc
index bf7faef209b..0f7d0eaeb33 100644
--- a/tensorflow/stream_executor/dso_loader.cc
+++ b/tensorflow/stream_executor/dso_loader.cc
@@ -138,9 +138,9 @@ string GetCudnnVersion() { return ""; }
 static std::vector<string>* CreatePrimordialRpaths() {
   auto rpaths = new std::vector<string>;
 #if defined(__APPLE__)
-  rpaths->push_back("driver/driver_sh.runfiles/third_party/gpus/cuda/lib");
+  rpaths->push_back("driver/driver_sh.runfiles/org_tensorflow/third_party/gpus/cuda/lib");
 #else
-  rpaths->push_back("driver/driver_sh.runfiles/third_party/gpus/cuda/lib64");
+  rpaths->push_back("driver/driver_sh.runfiles/org_tensorflow/third_party/gpus/cuda/lib64");
 #endif
   return rpaths;
 }
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 1e73e00a3f4..4eb5619ecd7 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -234,7 +234,7 @@ def tf_gen_op_wrapper_py(name, out=None, hidden=[], visibility=None, deps=[],
 # TODO(opensource): we need to enable this to work around the hidden symbol
 # __cudaRegisterFatBinary error. Need more investigations.
 def tf_cc_test(name, deps, linkstatic=0, tags=[], data=[], size="medium",
-               suffix="", args=None):
+               suffix="", args=None, linkopts=[]):
   name = name.replace(".cc", "")
   native.cc_test(name="%s%s" % (name.replace("/", "_"), suffix),
                  size=size,
@@ -243,7 +243,7 @@ def tf_cc_test(name, deps, linkstatic=0, tags=[], data=[], size="medium",
                  copts=tf_copts(),
                  data=data,
                  deps=deps,
-                 linkopts=["-lpthread", "-lm"],
+                 linkopts=["-lpthread", "-lm"] + linkopts,
                  linkstatic=linkstatic,
                  tags=tags,)
 
@@ -254,13 +254,15 @@ def tf_cc_test_gpu(name, deps, linkstatic=0, tags=[], data=[], size="medium",
   tf_cc_test(name, deps, linkstatic=linkstatic, tags=tags, data=data,
              size=size, suffix=suffix, args=args)
 
-def tf_cuda_cc_test(name, deps, tags=[], data=[], size="medium",linkstatic=0,args=[]):
+def tf_cuda_cc_test(name, deps, tags=[], data=[], size="medium", linkstatic=0,
+                    args=[], linkopts=[]):
   tf_cc_test(name=name,
             deps=deps,
             tags=tags + ["manual"],
             data=data,
             size=size,
             linkstatic=linkstatic,
+             linkopts=linkopts,
             args=args)
   tf_cc_test(name=name,
              suffix="_gpu",
@@ -269,21 +271,26 @@ def tf_cuda_cc_test(name, deps, tags=[], data=[], size="medium",linkstatic=0,arg
              tags=tags + tf_cuda_tests_tags(),
              data=data,
              size=size,
+             linkopts=linkopts,
              args=args)
 
 # Create a cc_test for each of the tensorflow tests listed in "tests"
-def tf_cc_tests(tests, deps, linkstatic=0, tags=[], size="medium", args=None):
+def tf_cc_tests(tests, deps, linkstatic=0, tags=[], size="medium", args=None,
+                linkopts=[]):
   for t in tests:
-    tf_cc_test(t, deps, linkstatic, tags=tags, size=size, args=args)
+    tf_cc_test(t, deps, linkstatic, tags=tags, size=size, args=args,
+               linkopts=linkopts)
 
 def tf_cc_tests_gpu(tests, deps, linkstatic=0, tags=[], size="medium",
                     args=None):
   tf_cc_tests(tests, deps, linkstatic, tags=tags, size=size, args=args)
 
-def tf_cuda_cc_tests(tests, deps, tags=[], size="medium", linkstatic=0, args=None):
+def tf_cuda_cc_tests(tests, deps, tags=[], size="medium", linkstatic=0,
+                     args=None, linkopts=[]):
   for t in tests:
-    tf_cuda_cc_test(t, deps, tags=tags, size=size, linkstatic=linkstatic, args=args)
+    tf_cuda_cc_test(t, deps, tags=tags, size=size, linkstatic=linkstatic,
+                    args=args, linkopts=linkopts)
 
 def _cuda_copts():
   """Gets the appropriate set of copts for (maybe) CUDA compilation.
diff --git a/tensorflow/tools/docs/gen_docs_test.sh b/tensorflow/tools/docs/gen_docs_test.sh
index 9375784dc23..70236383222 100755
--- a/tensorflow/tools/docs/gen_docs_test.sh
+++ b/tensorflow/tools/docs/gen_docs_test.sh
@@ -16,7 +16,12 @@
 
 set -eux
-TFDIR=$TEST_SRCDIR/tensorflow
+if [ -d $TEST_SRCDIR/org_tensorflow ]; then
+  TFDIR=$TEST_SRCDIR/org_tensorflow/tensorflow
+else
+  # Support 0.2.1- runfiles.
+  TFDIR=$TEST_SRCDIR/tensorflow
+fi
 DOXYGEN=doxygen
 DOXYGEN_CONFIG="tf-doxy_for_md-config"
 TMP_DIR=/tmp/tensorflow-docs
diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh
index 1ae6926b676..7f123937e8b 100755
--- a/tensorflow/tools/pip_package/build_pip_package.sh
+++ b/tensorflow/tools/pip_package/build_pip_package.sh
@@ -32,17 +32,38 @@ function main() {
     echo "Could not find bazel-bin. Did you run from the root of the build tree?"
     exit 1
   fi
-  cp -R \
-    bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/{tensorflow,external} \
-    ${TMPDIR}
+
+  if [ ! -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow ]; then
+    # Really old (0.2.1-) runfiles, without workspace name.
+    cp -R \
+      bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/{tensorflow,external} \
+      "${TMPDIR}"
+    RUNFILES=bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles
+  else
+    if [ -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/external ]; then
+      # Old-style runfiles structure (--legacy_external_runfiles).
+      cp -R \
+        bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/{tensorflow,external} \
+        "${TMPDIR}"
+    else
+      # New-style runfiles structure (--nolegacy_external_runfiles).
+      cp -R \
+        bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/tensorflow \
+        "${TMPDIR}"
+      mkdir "${TMPDIR}/external"
+      # Note: this makes an extra copy of org_tensorflow.
+      cp -R \
+        bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles \
+        "${TMPDIR}/external"
+    fi
+    RUNFILES=bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow
+  fi
+
   # protobuf pip package doesn't ship with header files. Copy the headers
   # over so user defined ops can be compiled.
   rsync --include "*/" --include "*.h" --exclude "*" --prune-empty-dirs -a \
-    bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/google \
-    ${TMPDIR}
-  rsync -a \
-    bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/third_party/eigen3 \
-    ${TMPDIR}/third_party
+    $RUNFILES/google ${TMPDIR}
+  rsync -a $RUNFILES/third_party/eigen3 ${TMPDIR}/third_party
+
   cp tensorflow/tools/pip_package/MANIFEST.in ${TMPDIR}
   cp tensorflow/tools/pip_package/README ${TMPDIR}
diff --git a/util/python/python_config.sh b/util/python/python_config.sh
index 83e38566906..7554765003d 100755
--- a/util/python/python_config.sh
+++ b/util/python/python_config.sh
@@ -16,11 +16,16 @@
 
 set -e -o errexit
 
-# Prefix expected paths with ./ locally and external/reponame/ for remote repos.
-# TODO(kchodorow): remove once runfiles paths are fixed, see
-# https://github.com/bazelbuild/bazel/issues/848.
-script_path=$(dirname $(dirname $(dirname "$0")))
-script_path=${script_path:-.}
+if [ -d "../org_tensorflow" ]; then
+  script_path="../org_tensorflow"
+else
+  # Prefix expected paths with ./ locally and external/reponame/ for remote repos.
+  # TODO(kchodorow): remove once runfiles paths are fixed, see
+  # https://github.com/bazelbuild/bazel/issues/848.
+  script_path=$(dirname $(dirname $(dirname "$0")))
+  script_path=${script_path:-.}
+fi
+
 EXPECTED_PATHS="$script_path/util/python/python_include"\
 " $script_path/util/python/python_lib"\
 " $script_path/third_party/py/numpy/numpy_include"