Change the default dtype for SobolSample op to float32.

PiperOrigin-RevId: 291893007
Change-Id: Ica87ba033b5dcb7796d3b61f5a34eea3e026cfed
This commit is contained in:
Matej Rizman 2020-01-28 02:22:48 -08:00 committed by TensorFlower Gardener
parent 46644c6c58
commit e2b70c0f80
8 changed files with 14 additions and 9 deletions

View File

@ -20,7 +20,7 @@ op {
name: "dtype"
type: "type"
default_value {
type: DT_DOUBLE
type: DT_FLOAT
}
allowed_values {
list {

View File

@ -1918,7 +1918,7 @@ REGISTER_OP("SobolSample")
.Input("dim: int32")
.Input("num_results: int32")
.Input("skip: int32")
.Attr("dtype: {float, double} = DT_DOUBLE")
.Attr("dtype: {float, double} = DT_FLOAT")
.Output("samples: dtype")
.SetShapeFn([](shape_inference::InferenceContext* c) {
ShapeHandle unused;

View File

@ -4564,7 +4564,7 @@ def exp(x, name=None):
@tf_export("math.sobol_sample")
def sobol_sample(dim, num_results, skip=0, dtype=None, name=None):
def sobol_sample(dim, num_results, skip=0, dtype=dtypes.float32, name=None):
"""Generates points from the Sobol sequence.
Creates a Sobol sequence with `num_results` samples. Each sample has dimension
@ -4576,8 +4576,8 @@ def sobol_sample(dim, num_results, skip=0, dtype=None, name=None):
points to return in the output.
skip: (Optional) Positive scalar `Tensor` of dtype int32. The number of
initial points of the Sobol sequence to skip. Default value is 0.
dtype: (Optional) The dtype of the sample. One of: `float32` or `float64`.
Default value is determined by the C++ kernel.
dtype: (Optional) The `tf.Dtype` of the sample. One of: `tf.float32` or
`tf.float64`. Defaults to `tf.float32`.
name: (Optional) Python `str` name prefixed to ops created by this function.
Returns:

View File

@ -124,6 +124,11 @@ class SobolSampleOpTest(test_util.TensorFlowTestCase):
s = math_ops.sobol_sample(10, 100, dtype=dtypes.float32)
self.assertAllEqual([100, 10], self.evaluate(s).shape)
def test_default_dtype(self):
# Create an op without specifying the dtype. Dtype should be float32 in
# this case.
s = math_ops.sobol_sample(10, 100)
self.assertEqual(dtypes.float32, s.dtype)
if __name__ == '__main__':
googletest.main()

View File

@ -422,7 +422,7 @@ tf_module {
}
member_method {
name: "sobol_sample"
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], "
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "softmax"

View File

@ -3846,7 +3846,7 @@ tf_module {
}
member_method {
name: "SobolSample"
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float64\'>\", \'None\'], "
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "Softmax"

View File

@ -422,7 +422,7 @@ tf_module {
}
member_method {
name: "sobol_sample"
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], "
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "softmax"

View File

@ -3846,7 +3846,7 @@ tf_module {
}
member_method {
name: "SobolSample"
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float64\'>\", \'None\'], "
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
name: "Softmax"