Change the default dtype for SobolSample op to float32.
PiperOrigin-RevId: 291893007 Change-Id: Ica87ba033b5dcb7796d3b61f5a34eea3e026cfed
This commit is contained in:
parent
46644c6c58
commit
e2b70c0f80
tensorflow
core/ops
python/ops
tools/api/golden
@ -20,7 +20,7 @@ op {
|
|||||||
name: "dtype"
|
name: "dtype"
|
||||||
type: "type"
|
type: "type"
|
||||||
default_value {
|
default_value {
|
||||||
type: DT_DOUBLE
|
type: DT_FLOAT
|
||||||
}
|
}
|
||||||
allowed_values {
|
allowed_values {
|
||||||
list {
|
list {
|
||||||
|
@ -1918,7 +1918,7 @@ REGISTER_OP("SobolSample")
|
|||||||
.Input("dim: int32")
|
.Input("dim: int32")
|
||||||
.Input("num_results: int32")
|
.Input("num_results: int32")
|
||||||
.Input("skip: int32")
|
.Input("skip: int32")
|
||||||
.Attr("dtype: {float, double} = DT_DOUBLE")
|
.Attr("dtype: {float, double} = DT_FLOAT")
|
||||||
.Output("samples: dtype")
|
.Output("samples: dtype")
|
||||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
ShapeHandle unused;
|
ShapeHandle unused;
|
||||||
|
@ -4564,7 +4564,7 @@ def exp(x, name=None):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("math.sobol_sample")
|
@tf_export("math.sobol_sample")
|
||||||
def sobol_sample(dim, num_results, skip=0, dtype=None, name=None):
|
def sobol_sample(dim, num_results, skip=0, dtype=dtypes.float32, name=None):
|
||||||
"""Generates points from the Sobol sequence.
|
"""Generates points from the Sobol sequence.
|
||||||
|
|
||||||
Creates a Sobol sequence with `num_results` samples. Each sample has dimension
|
Creates a Sobol sequence with `num_results` samples. Each sample has dimension
|
||||||
@ -4576,8 +4576,8 @@ def sobol_sample(dim, num_results, skip=0, dtype=None, name=None):
|
|||||||
points to return in the output.
|
points to return in the output.
|
||||||
skip: (Optional) Positive scalar `Tensor` of dtype int32. The number of
|
skip: (Optional) Positive scalar `Tensor` of dtype int32. The number of
|
||||||
initial points of the Sobol sequence to skip. Default value is 0.
|
initial points of the Sobol sequence to skip. Default value is 0.
|
||||||
dtype: (Optional) The dtype of the sample. One of: `float32` or `float64`.
|
dtype: (Optional) The `tf.Dtype` of the sample. One of: `tf.float32` or
|
||||||
Default value is determined by the C++ kernel.
|
`tf.float64`. Defaults to `tf.float32`.
|
||||||
name: (Optional) Python `str` name prefixed to ops created by this function.
|
name: (Optional) Python `str` name prefixed to ops created by this function.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -124,6 +124,11 @@ class SobolSampleOpTest(test_util.TensorFlowTestCase):
|
|||||||
s = math_ops.sobol_sample(10, 100, dtype=dtypes.float32)
|
s = math_ops.sobol_sample(10, 100, dtype=dtypes.float32)
|
||||||
self.assertAllEqual([100, 10], self.evaluate(s).shape)
|
self.assertAllEqual([100, 10], self.evaluate(s).shape)
|
||||||
|
|
||||||
|
def test_default_dtype(self):
|
||||||
|
# Create an op without specifying the dtype. Dtype should be float32 in
|
||||||
|
# this case.
|
||||||
|
s = math_ops.sobol_sample(10, 100)
|
||||||
|
self.assertEqual(dtypes.float32, s.dtype)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
googletest.main()
|
googletest.main()
|
||||||
|
@ -422,7 +422,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "sobol_sample"
|
name: "sobol_sample"
|
||||||
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], "
|
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'float32\'>\", \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "softmax"
|
name: "softmax"
|
||||||
|
@ -3846,7 +3846,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "SobolSample"
|
name: "SobolSample"
|
||||||
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float64\'>\", \'None\'], "
|
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "Softmax"
|
name: "Softmax"
|
||||||
|
@ -422,7 +422,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "sobol_sample"
|
name: "sobol_sample"
|
||||||
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], "
|
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'float32\'>\", \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "softmax"
|
name: "softmax"
|
||||||
|
@ -3846,7 +3846,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "SobolSample"
|
name: "SobolSample"
|
||||||
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float64\'>\", \'None\'], "
|
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "Softmax"
|
name: "Softmax"
|
||||||
|
Loading…
Reference in New Issue
Block a user