Change the default dtype for SobolSample op to float32.

PiperOrigin-RevId: 291893007
Change-Id: Ica87ba033b5dcb7796d3b61f5a34eea3e026cfed
This commit is contained in:
Matej Rizman 2020-01-28 02:22:48 -08:00 committed by TensorFlower Gardener
parent 46644c6c58
commit e2b70c0f80
8 changed files with 14 additions and 9 deletions

View File

@ -20,7 +20,7 @@ op {
name: "dtype" name: "dtype"
type: "type" type: "type"
default_value { default_value {
type: DT_DOUBLE type: DT_FLOAT
} }
allowed_values { allowed_values {
list { list {

View File

@ -1918,7 +1918,7 @@ REGISTER_OP("SobolSample")
.Input("dim: int32") .Input("dim: int32")
.Input("num_results: int32") .Input("num_results: int32")
.Input("skip: int32") .Input("skip: int32")
.Attr("dtype: {float, double} = DT_DOUBLE") .Attr("dtype: {float, double} = DT_FLOAT")
.Output("samples: dtype") .Output("samples: dtype")
.SetShapeFn([](shape_inference::InferenceContext* c) { .SetShapeFn([](shape_inference::InferenceContext* c) {
ShapeHandle unused; ShapeHandle unused;

View File

@ -4564,7 +4564,7 @@ def exp(x, name=None):
@tf_export("math.sobol_sample") @tf_export("math.sobol_sample")
def sobol_sample(dim, num_results, skip=0, dtype=None, name=None): def sobol_sample(dim, num_results, skip=0, dtype=dtypes.float32, name=None):
"""Generates points from the Sobol sequence. """Generates points from the Sobol sequence.
Creates a Sobol sequence with `num_results` samples. Each sample has dimension Creates a Sobol sequence with `num_results` samples. Each sample has dimension
@ -4576,8 +4576,8 @@ def sobol_sample(dim, num_results, skip=0, dtype=None, name=None):
points to return in the output. points to return in the output.
skip: (Optional) Positive scalar `Tensor` of dtype int32. The number of skip: (Optional) Positive scalar `Tensor` of dtype int32. The number of
initial points of the Sobol sequence to skip. Default value is 0. initial points of the Sobol sequence to skip. Default value is 0.
dtype: (Optional) The dtype of the sample. One of: `float32` or `float64`. dtype: (Optional) The `tf.Dtype` of the sample. One of: `tf.float32` or
Default value is determined by the C++ kernel. `tf.float64`. Defaults to `tf.float32`.
name: (Optional) Python `str` name prefixed to ops created by this function. name: (Optional) Python `str` name prefixed to ops created by this function.
Returns: Returns:

View File

@ -124,6 +124,11 @@ class SobolSampleOpTest(test_util.TensorFlowTestCase):
s = math_ops.sobol_sample(10, 100, dtype=dtypes.float32) s = math_ops.sobol_sample(10, 100, dtype=dtypes.float32)
self.assertAllEqual([100, 10], self.evaluate(s).shape) self.assertAllEqual([100, 10], self.evaluate(s).shape)
def test_default_dtype(self):
# Create an op without specifying the dtype. Dtype should be float32 in
# this case.
s = math_ops.sobol_sample(10, 100)
self.assertEqual(dtypes.float32, s.dtype)
if __name__ == '__main__': if __name__ == '__main__':
googletest.main() googletest.main()

View File

@ -422,7 +422,7 @@ tf_module {
} }
member_method { member_method {
name: "sobol_sample" name: "sobol_sample"
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], " argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'float32\'>\", \'None\'], "
} }
member_method { member_method {
name: "softmax" name: "softmax"

View File

@ -3846,7 +3846,7 @@ tf_module {
} }
member_method { member_method {
name: "SobolSample" name: "SobolSample"
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float64\'>\", \'None\'], " argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
} }
member_method { member_method {
name: "Softmax" name: "Softmax"

View File

@ -422,7 +422,7 @@ tf_module {
} }
member_method { member_method {
name: "sobol_sample" name: "sobol_sample"
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], " argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'float32\'>\", \'None\'], "
} }
member_method { member_method {
name: "softmax" name: "softmax"

View File

@ -3846,7 +3846,7 @@ tf_module {
} }
member_method { member_method {
name: "SobolSample" name: "SobolSample"
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float64\'>\", \'None\'], " argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
} }
member_method { member_method {
name: "Softmax" name: "Softmax"