From e2b70c0f80aee73ea4ff6679ebf2c7c1f9b5d4b4 Mon Sep 17 00:00:00 2001 From: Matej Rizman Date: Tue, 28 Jan 2020 02:22:48 -0800 Subject: [PATCH] Change the default dtype for SobolSample op to float32. PiperOrigin-RevId: 291893007 Change-Id: Ica87ba033b5dcb7796d3b61f5a34eea3e026cfed --- tensorflow/core/ops/compat/ops_history_v1/SobolSample.pbtxt | 2 +- tensorflow/core/ops/math_ops.cc | 2 +- tensorflow/python/ops/math_ops.py | 6 +++--- tensorflow/python/ops/sobol_ops_test.py | 5 +++++ tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt | 2 +- tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt | 2 +- tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt | 2 +- tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt | 2 +- 8 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/ops/compat/ops_history_v1/SobolSample.pbtxt b/tensorflow/core/ops/compat/ops_history_v1/SobolSample.pbtxt index 182b1e5becf..4fe7c45282a 100644 --- a/tensorflow/core/ops/compat/ops_history_v1/SobolSample.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v1/SobolSample.pbtxt @@ -20,7 +20,7 @@ op { name: "dtype" type: "type" default_value { - type: DT_DOUBLE + type: DT_FLOAT } allowed_values { list { diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index c0bf0eb6bf2..7ba946faf92 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1918,7 +1918,7 @@ REGISTER_OP("SobolSample") .Input("dim: int32") .Input("num_results: int32") .Input("skip: int32") - .Attr("dtype: {float, double} = DT_DOUBLE") + .Attr("dtype: {float, double} = DT_FLOAT") .Output("samples: dtype") .SetShapeFn([](shape_inference::InferenceContext* c) { ShapeHandle unused; diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index ee7b1194610..83b8d2f376d 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -4564,7 +4564,7 @@ def exp(x, name=None): @tf_export("math.sobol_sample") -def sobol_sample(dim, num_results, skip=0, dtype=None, name=None): +def sobol_sample(dim, num_results, skip=0, dtype=dtypes.float32, name=None): """Generates points from the Sobol sequence. Creates a Sobol sequence with `num_results` samples. Each sample has dimension @@ -4576,8 +4576,8 @@ def sobol_sample(dim, num_results, skip=0, dtype=None, name=None): points to return in the output. skip: (Optional) Positive scalar `Tensor` of dtype int32. The number of initial points of the Sobol sequence to skip. Default value is 0. - dtype: (Optional) The dtype of the sample. One of: `float32` or `float64`. - Default value is determined by the C++ kernel. + dtype: (Optional) The `tf.Dtype` of the sample. One of: `tf.float32` or + `tf.float64`. Defaults to `tf.float32`. name: (Optional) Python `str` name prefixed to ops created by this function. Returns: diff --git a/tensorflow/python/ops/sobol_ops_test.py b/tensorflow/python/ops/sobol_ops_test.py index 2f99a5e0db3..cf717343c3b 100644 --- a/tensorflow/python/ops/sobol_ops_test.py +++ b/tensorflow/python/ops/sobol_ops_test.py @@ -124,6 +124,11 @@ class SobolSampleOpTest(test_util.TensorFlowTestCase): s = math_ops.sobol_sample(10, 100, dtype=dtypes.float32) self.assertAllEqual([100, 10], self.evaluate(s).shape) + def test_default_dtype(self): + # Create an op without specifying the dtype. Dtype should be float32 in + # this case. + s = math_ops.sobol_sample(10, 100) + self.assertEqual(dtypes.float32, s.dtype) if __name__ == '__main__': googletest.main() diff --git a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt index e4ab4e8f88a..c1448a85833 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt @@ -422,7 +422,7 @@ tf_module { } member_method { name: "sobol_sample" - argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], " + argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"\", \'None\'], " } member_method { name: "softmax" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 2598a6807a0..c43f311d5cb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -3846,7 +3846,7 @@ tf_module { } member_method { name: "SobolSample" - argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " } member_method { name: "Softmax" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt index d68ca9759d4..075946b64fd 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt @@ -422,7 +422,7 @@ tf_module { } member_method { name: "sobol_sample" - argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], " + argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"\", \'None\'], " } member_method { name: "softmax" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 2598a6807a0..c43f311d5cb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -3846,7 +3846,7 @@ tf_module { } member_method { name: "SobolSample" - argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " } member_method { name: "Softmax"