Add tf.math.sobol_sample for generating Sobol sequences (CPU implementation only).
PiperOrigin-RevId: 286365254 Change-Id: Ia0c2482f4f264f36fe61db5f9c72f24db35faf65
This commit is contained in:
parent
ecce7990b9
commit
50a1c3be8b
@ -1254,7 +1254,9 @@ cc_library(
|
||||
cc_library(
|
||||
name = "dynamic_kernels_impl",
|
||||
visibility = [":__subpackages__"],
|
||||
deps = [],
|
||||
deps = [
|
||||
"//tensorflow/core/kernels:sobol_op",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
41
tensorflow/core/api_def/base_api/api_def_SobolSample.pbtxt
Normal file
41
tensorflow/core/api_def/base_api/api_def_SobolSample.pbtxt
Normal file
@ -0,0 +1,41 @@
|
||||
op {
|
||||
graph_op_name: "SobolSample"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "dim"
|
||||
description: <<END
|
||||
Positive scalar `Tensor` representing each sample's dimension.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "num_results"
|
||||
description: <<END
|
||||
Positive scalar `Tensor` of dtype int32. The number of Sobol points to return
|
||||
in the output.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "skip"
|
||||
description: <<END
|
||||
Positive scalar `Tensor` of dtype int32. The number of initial points of the
|
||||
Sobol sequence to skip.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
description: <<END
|
||||
The type of the sample. One of: `float32` or `float64`.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "samples"
|
||||
description: <<END
|
||||
`Tensor` of samples from Sobol sequence with `shape` [num_results, dim].
|
||||
END
|
||||
}
|
||||
summary: "Generates points from the Sobol sequence."
|
||||
description: <<END
|
||||
Creates a Sobol sequence with `num_results` samples. Each sample has dimension
|
||||
`dim`. Skips the first `skip` samples.
|
||||
END
|
||||
}
|
@ -6705,6 +6705,7 @@ filegroup(
|
||||
"decode_proto_op.cc",
|
||||
"encode_proto_op.cc",
|
||||
"rpc_op.cc",
|
||||
"sobol_op.cc",
|
||||
# Excluded due to experimental status:
|
||||
"debug_ops.*",
|
||||
"mutex_ops.*",
|
||||
@ -8278,3 +8279,16 @@ exports_files([
|
||||
"sparse_reshape_op.cc",
|
||||
"unary_ops_composition.cc",
|
||||
])
|
||||
|
||||
tf_kernel_library(
|
||||
name = "sobol_op",
|
||||
srcs = [
|
||||
"sobol_op.cc",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//third_party/eigen3",
|
||||
"@sobol_data",
|
||||
],
|
||||
)
|
||||
|
182
tensorflow/core/kernels/sobol_op.cc
Normal file
182
tensorflow/core/kernels/sobol_op.cc
Normal file
@ -0,0 +1,182 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Based on "Notes on generating Sobol sequences. August 2008" by Joe and Kuo.
|
||||
// [1] https://web.maths.unsw.edu.au/~fkuo/sobol/joe-kuo-notes.pdf
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "sobol_data.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/platform_strings.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Embed the platform strings in this binary.
|
||||
TF_PLATFORM_STRINGS()
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
namespace {
|
||||
|
||||
// Each thread will calculate at least kMinBlockSize points in the sequence.
|
||||
constexpr int kMinBlockSize = 512;
|
||||
|
||||
// Returns number of digits in binary representation of n.
|
||||
// Example: n=13. Binary representation is 1101. NumBinaryDigits(13) -> 4.
|
||||
int NumBinaryDigits(int n) {
|
||||
return static_cast<int>(std::log2(n) + 1);
|
||||
}
|
||||
|
||||
// Returns position of rightmost zero digit in binary representation of n.
|
||||
// Example: n=13. Binary representation is 1101. RightmostZeroBit(13) -> 1.
|
||||
int RightmostZeroBit(int n) {
|
||||
int k = 0;
|
||||
while (n & 1) {
|
||||
n >>= 1;
|
||||
++k;
|
||||
}
|
||||
return k;
|
||||
}
|
||||
|
||||
// Returns an integer representation of point `i` in the Sobol sequence of
|
||||
// dimension `dim` using the given direction numbers.
|
||||
Eigen::VectorXi GetFirstPoint(int i, int dim,
|
||||
const Eigen::MatrixXi& direction_numbers) {
|
||||
// Index variables used in this function, consistent with notation in [1].
|
||||
// i - point in the Sobol sequence
|
||||
// j - dimension
|
||||
// k - binary digit
|
||||
Eigen::VectorXi integer_sequence = Eigen::VectorXi::Zero(dim);
|
||||
// go/wiki/Sobol_sequence#A_fast_algorithm_for_the_construction_of_Sobol_sequences
|
||||
int gray_code = i ^ (i >> 1);
|
||||
int num_digits = NumBinaryDigits(i);
|
||||
for (int j = 0; j < dim; ++j) {
|
||||
for (int k = 0; k < num_digits; ++k) {
|
||||
if ((gray_code >> k) & 1) integer_sequence(j) ^= direction_numbers(j, k);
|
||||
}
|
||||
}
|
||||
return integer_sequence;
|
||||
}
|
||||
|
||||
// Calculates `num_results` Sobol points of dimension `dim` starting at the
|
||||
// point `start_point + skip` and writes them into `output` starting at point
|
||||
// `start_point`.
|
||||
template <typename T>
|
||||
void CalculateSobolSample(int32_t dim, int32_t num_results, int32_t skip,
|
||||
int32_t start_point,
|
||||
typename TTypes<T>::Flat output) {
|
||||
// Index variables used in this function, consistent with notation in [1].
|
||||
// i - point in the Sobol sequence
|
||||
// j - dimension
|
||||
// k - binary digit
|
||||
const int num_digits =
|
||||
NumBinaryDigits(skip + start_point + num_results + 1);
|
||||
Eigen::MatrixXi direction_numbers(dim, num_digits);
|
||||
|
||||
// Shift things so we can use integers everywhere. Before we write to output,
|
||||
// divide by constant to convert back to floats.
|
||||
const T normalizing_constant = 1./(1 << num_digits);
|
||||
for (int j = 0; j < dim; ++j) {
|
||||
for (int k = 0; k < num_digits; ++k) {
|
||||
direction_numbers(j, k) = sobol_data::kDirectionNumbers[j][k]
|
||||
<< (num_digits - k - 1);
|
||||
}
|
||||
}
|
||||
|
||||
// If needed, skip ahead to the appropriate point in the sequence. Otherwise
|
||||
// we start with the first column of direction numbers.
|
||||
Eigen::VectorXi integer_sequence =
|
||||
(skip + start_point > 0)
|
||||
? GetFirstPoint(skip + start_point + 1, dim, direction_numbers)
|
||||
: direction_numbers.col(0);
|
||||
|
||||
for (int j = 0; j < dim; ++j) {
|
||||
output(start_point * dim + j) = integer_sequence(j) * normalizing_constant;
|
||||
}
|
||||
// go/wiki/Sobol_sequence#A_fast_algorithm_for_the_construction_of_Sobol_sequences
|
||||
for (int i = start_point + 1; i < num_results + start_point; ++i) {
|
||||
// The Gray code for the current point differs from the preceding one by
|
||||
// just a single bit -- the rightmost bit.
|
||||
int k = RightmostZeroBit(i + skip);
|
||||
// Update the current point from the preceding one with a single XOR
|
||||
// operation per dimension.
|
||||
for (int j = 0; j < dim; ++j) {
|
||||
integer_sequence(j) ^= direction_numbers(j, k);
|
||||
output(i * dim + j) = integer_sequence(j) * normalizing_constant;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Device, typename T>
|
||||
class SobolSampleOp : public OpKernel {
|
||||
public:
|
||||
explicit SobolSampleOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
int32_t dim = context->input(0).scalar<int32_t>()();
|
||||
int32_t num_results = context->input(1).scalar<int32_t>()();
|
||||
int32_t skip = context->input(2).scalar<int32_t>()();
|
||||
|
||||
OP_REQUIRES(context, dim >= 1,
|
||||
errors::InvalidArgument("dim must be at least one"));
|
||||
OP_REQUIRES(context, dim <= sobol_data::kMaxSobolDim,
|
||||
errors::InvalidArgument("dim must be at most ",
|
||||
sobol_data::kMaxSobolDim));
|
||||
OP_REQUIRES(context, num_results >= 1,
|
||||
errors::InvalidArgument("num_results must be at least one"));
|
||||
OP_REQUIRES(context, skip >= 0,
|
||||
errors::InvalidArgument("skip must be non-negative"));
|
||||
OP_REQUIRES(context,
|
||||
num_results < std::numeric_limits<int32_t>::max() - skip,
|
||||
errors::InvalidArgument("num_results+skip must be less than ",
|
||||
std::numeric_limits<int32_t>::max()));
|
||||
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(
|
||||
0, TensorShape({num_results, dim}), &output));
|
||||
auto output_flat = output->flat<T>();
|
||||
const DeviceBase::CpuWorkerThreads& worker_threads =
|
||||
*(context->device()->tensorflow_cpu_worker_threads());
|
||||
int num_threads = worker_threads.num_threads;
|
||||
int block_size = std::max(
|
||||
kMinBlockSize, static_cast<int>(std::ceil(
|
||||
static_cast<float>(num_results) / num_threads)));
|
||||
worker_threads.workers->TransformRangeConcurrently(
|
||||
block_size, num_results /* total */,
|
||||
[&dim, &skip, &output_flat](const int start, const int end) {
|
||||
CalculateSobolSample<T>(dim, end - start /* num_results */, skip,
|
||||
start, output_flat);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("SobolSample").Device(DEVICE_CPU).TypeConstraint<double>("dtype"),
|
||||
SobolSampleOp<CPUDevice, double>);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("SobolSample").Device(DEVICE_CPU).TypeConstraint<float>("dtype"),
|
||||
SobolSampleOp<CPUDevice, float>);
|
||||
|
||||
} // namespace tensorflow
|
@ -1907,4 +1907,28 @@ REGISTER_OP("NextAfter")
|
||||
.Output("output: T")
|
||||
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
|
||||
|
||||
REGISTER_OP("SobolSample")
|
||||
.Input("dim: int32")
|
||||
.Input("num_results: int32")
|
||||
.Input("skip: int32")
|
||||
.Attr("dtype: {float, double} = DT_DOUBLE")
|
||||
.Output("samples: dtype")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
ShapeHandle unused;
|
||||
// inputs must be scalars
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
|
||||
const Tensor* dim_t = c->input_tensor(0);
|
||||
const Tensor* num_results_t = c->input_tensor(1);
|
||||
if (dim_t == nullptr || num_results_t == nullptr) {
|
||||
c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
|
||||
return Status::OK();
|
||||
}
|
||||
const int32 output_size =
|
||||
dim_t->scalar<int32>()() * num_results_t->scalar<int32>()();
|
||||
c->set_output(0, c->Vector(output_size));
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -593,4 +593,15 @@ TEST(MathOpsTest, Bincount_ShapeFn) {
|
||||
INFER_OK(op, "[?];[];?", "[?]");
|
||||
INFER_OK(op, "[?];[];[?]", "[?]");
|
||||
}
|
||||
|
||||
TEST(MathOpsTest, SobolSample) {
|
||||
ShapeInferenceTestOp op("SobolSample");
|
||||
|
||||
// All inputs should be scalar.
|
||||
INFER_ERROR("must be rank 0", op, "[1];?;?");
|
||||
INFER_ERROR("must be rank 0", op, "?;[1];?");
|
||||
INFER_ERROR("must be rank 0", op, "?;?;[1]");
|
||||
|
||||
INFER_OK(op, "[];[];[]", "[?]");
|
||||
}
|
||||
} // end namespace tensorflow
|
||||
|
@ -4828,6 +4828,23 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "sobol_ops_test",
|
||||
size = "small",
|
||||
srcs = ["ops/sobol_ops_test.py"],
|
||||
additional_deps = [
|
||||
":framework_for_generated_wrappers",
|
||||
":framework_test_lib",
|
||||
":math_ops",
|
||||
":platform_test",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
kernels = [
|
||||
"//tensorflow/core/kernels:libtfkernel_sobol_op.so",
|
||||
],
|
||||
tags = ["no_windows_gpu"],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "special_math_ops_test",
|
||||
size = "medium",
|
||||
|
@ -4450,3 +4450,27 @@ def exp(x, name=None):
|
||||
|
||||
|
||||
# pylint: enable=g-docstring-has-escape
|
||||
|
||||
|
||||
@tf_export("math.sobol_sample")
|
||||
def sobol_sample(dim, num_results, skip=0, dtype=None, name=None):
|
||||
"""Generates points from the Sobol sequence.
|
||||
|
||||
Creates a Sobol sequence with `num_results` samples. Each sample has dimension
|
||||
`dim`. Skips the first `skip` samples.
|
||||
|
||||
Args:
|
||||
dim: Positive scalar `Tensor` representing each sample's dimension.
|
||||
num_results: Positive scalar `Tensor` of dtype int32. The number of Sobol
|
||||
points to return in the output.
|
||||
skip: (Optional) Positive scalar `Tensor` of dtype int32. The number of
|
||||
initial points of the Sobol sequence to skip. Default value is 0.
|
||||
dtype: (Optional) The dtype of the sample. One of: `float32` or `float64`.
|
||||
Default value is determined by the C++ kernel.
|
||||
name: (Optional) Python `str` name prefixed to ops created by this function.
|
||||
|
||||
Returns:
|
||||
`Tensor` of samples from Sobol sequence with `shape` [num_results, dim].
|
||||
"""
|
||||
with ops.name_scope(name, "sobol", [dim, num_results, skip]):
|
||||
return gen_math_ops.sobol_sample(dim, num_results, skip, dtype=dtype)
|
||||
|
83
tensorflow/python/ops/sobol_ops_test.py
Normal file
83
tensorflow/python/ops/sobol_ops_test.py
Normal file
@ -0,0 +1,83 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests Sobol sequence generator."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class SobolSampleOpTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def test_basic(self):
|
||||
for dtype in [np.float64, np.float32]:
|
||||
expected = np.array([[.5, .5], [.75, .25], [.25, .75], [.375, .375]])
|
||||
sample = self.evaluate(math_ops.sobol_sample(2, 4, dtype=dtype))
|
||||
self.assertAllClose(expected, sample, 0.001)
|
||||
|
||||
def test_more_known_values(self):
|
||||
for dtype in [np.float64, np.float32]:
|
||||
sample = math_ops.sobol_sample(5, 31, dtype=dtype)
|
||||
expected = [[0.50, 0.50, 0.50, 0.50, 0.50],
|
||||
[0.75, 0.25, 0.25, 0.25, 0.75],
|
||||
[0.25, 0.75, 0.75, 0.75, 0.25],
|
||||
[0.375, 0.375, 0.625, 0.875, 0.375],
|
||||
[0.875, 0.875, 0.125, 0.375, 0.875],
|
||||
[0.625, 0.125, 0.875, 0.625, 0.625],
|
||||
[0.125, 0.625, 0.375, 0.125, 0.125],
|
||||
[0.1875, 0.3125, 0.9375, 0.4375, 0.5625],
|
||||
[0.6875, 0.8125, 0.4375, 0.9375, 0.0625],
|
||||
[0.9375, 0.0625, 0.6875, 0.1875, 0.3125],
|
||||
[0.4375, 0.5625, 0.1875, 0.6875, 0.8125],
|
||||
[0.3125, 0.1875, 0.3125, 0.5625, 0.9375],
|
||||
[0.8125, 0.6875, 0.8125, 0.0625, 0.4375],
|
||||
[0.5625, 0.4375, 0.0625, 0.8125, 0.1875],
|
||||
[0.0625, 0.9375, 0.5625, 0.3125, 0.6875],
|
||||
[0.09375, 0.46875, 0.46875, 0.65625, 0.28125],
|
||||
[0.59375, 0.96875, 0.96875, 0.15625, 0.78125],
|
||||
[0.84375, 0.21875, 0.21875, 0.90625, 0.53125],
|
||||
[0.34375, 0.71875, 0.71875, 0.40625, 0.03125],
|
||||
[0.46875, 0.09375, 0.84375, 0.28125, 0.15625],
|
||||
[0.96875, 0.59375, 0.34375, 0.78125, 0.65625],
|
||||
[0.71875, 0.34375, 0.59375, 0.03125, 0.90625],
|
||||
[0.21875, 0.84375, 0.09375, 0.53125, 0.40625],
|
||||
[0.15625, 0.15625, 0.53125, 0.84375, 0.84375],
|
||||
[0.65625, 0.65625, 0.03125, 0.34375, 0.34375],
|
||||
[0.90625, 0.40625, 0.78125, 0.59375, 0.09375],
|
||||
[0.40625, 0.90625, 0.28125, 0.09375, 0.59375],
|
||||
[0.28125, 0.28125, 0.15625, 0.21875, 0.71875],
|
||||
[0.78125, 0.78125, 0.65625, 0.71875, 0.21875],
|
||||
[0.53125, 0.03125, 0.40625, 0.46875, 0.46875],
|
||||
[0.03125, 0.53125, 0.90625, 0.96875, 0.96875]]
|
||||
self.assertAllClose(expected, self.evaluate(sample), .001)
|
||||
|
||||
def test_skip(self):
|
||||
dim = 10
|
||||
n = 50
|
||||
skip = 17
|
||||
for dtype in [np.float64, np.float32]:
|
||||
sample_noskip = math_ops.sobol_sample(dim, n + skip, dtype=dtype)
|
||||
sample_skip = math_ops.sobol_sample(dim, n, skip, dtype=dtype)
|
||||
|
||||
self.assertAllClose(
|
||||
self.evaluate(sample_noskip)[skip:, :], self.evaluate(sample_skip))
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
@ -420,6 +420,10 @@ tf_module {
|
||||
name: "sinh"
|
||||
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "sobol_sample"
|
||||
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "softmax"
|
||||
argspec: "args=[\'logits\', \'axis\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
|
@ -3836,6 +3836,10 @@ tf_module {
|
||||
name: "SnapshotDataset"
|
||||
argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'shard_size_bytes\', \'pending_snapshot_expiry_seconds\', \'num_reader_threads\', \'reader_buffer_size\', \'num_writer_threads\', \'writer_buffer_size\', \'shuffle_on_read\', \'seed\', \'seed2\', \'mode\', \'snapshot_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'10737418240\', \'86400\', \'1\', \'1\', \'1\', \'1\', \'False\', \'0\', \'0\', \'auto\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "SobolSample"
|
||||
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Softmax"
|
||||
argspec: "args=[\'logits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -420,6 +420,10 @@ tf_module {
|
||||
name: "sinh"
|
||||
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "sobol_sample"
|
||||
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "softmax"
|
||||
argspec: "args=[\'logits\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
@ -3836,6 +3836,10 @@ tf_module {
|
||||
name: "SnapshotDataset"
|
||||
argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'shard_size_bytes\', \'pending_snapshot_expiry_seconds\', \'num_reader_threads\', \'reader_buffer_size\', \'num_writer_threads\', \'writer_buffer_size\', \'shuffle_on_read\', \'seed\', \'seed2\', \'mode\', \'snapshot_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'10737418240\', \'86400\', \'1\', \'1\', \'1\', \'1\', \'False\', \'0\', \'0\', \'auto\', \'\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "SobolSample"
|
||||
argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float64\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Softmax"
|
||||
argspec: "args=[\'logits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -161,6 +161,7 @@ genrule(
|
||||
"@png//:LICENSE",
|
||||
"@com_google_protobuf//:LICENSE",
|
||||
"@snappy//:COPYING",
|
||||
"@sobol_data//:LICENSE",
|
||||
"@zlib_archive//:zlib.h",
|
||||
"@six_archive//:LICENSE",
|
||||
] + select({
|
||||
@ -232,6 +233,7 @@ genrule(
|
||||
"@png//:LICENSE",
|
||||
"@com_google_protobuf//:LICENSE",
|
||||
"@snappy//:COPYING",
|
||||
"@sobol_data//:LICENSE",
|
||||
"@zlib_archive//:zlib.h",
|
||||
"@grpc//:LICENSE",
|
||||
"@grpc//third_party/address_sorting:LICENSE",
|
||||
|
@ -47,6 +47,7 @@ py_binary(
|
||||
|
||||
# Add dynamic kernel dso files here.
|
||||
DYNAMIC_LOADED_KERNELS = [
|
||||
"//tensorflow/core/kernels:libtfkernel_sobol_op.so",
|
||||
]
|
||||
|
||||
COMMON_PIP_DEPS = [
|
||||
@ -160,6 +161,7 @@ filegroup(
|
||||
"@com_google_protobuf//:LICENSE",
|
||||
"@six_archive//:LICENSE",
|
||||
"@snappy//:COPYING",
|
||||
"@sobol_data//:LICENSE",
|
||||
"@swig//:LICENSE",
|
||||
"@termcolor_archive//:COPYING.txt",
|
||||
"@zlib_archive//:zlib.h",
|
||||
|
@ -38,6 +38,7 @@ load("//third_party/kissfft:workspace.bzl", kissfft = "repo")
|
||||
load("//third_party/pasta:workspace.bzl", pasta = "repo")
|
||||
load("//third_party/psimd:workspace.bzl", psimd = "repo")
|
||||
load("//third_party/pthreadpool:workspace.bzl", pthreadpool = "repo")
|
||||
load("//third_party/sobol_data:workspace.bzl", sobol_data = "repo")
|
||||
|
||||
def initialize_third_party():
|
||||
""" Load third party repositories. See above load() statements. """
|
||||
@ -58,6 +59,7 @@ def initialize_third_party():
|
||||
pasta()
|
||||
psimd()
|
||||
pthreadpool()
|
||||
sobol_data()
|
||||
|
||||
# Sanitize a dependency so that it works correctly from code that includes
|
||||
# TensorFlow as a submodule.
|
||||
|
1
third_party/sobol_data/BUILD
vendored
Normal file
1
third_party/sobol_data/BUILD
vendored
Normal file
@ -0,0 +1 @@
|
||||
# This empty BUILD file is required to make Bazel treat this directory as a package.
|
10
third_party/sobol_data/BUILD.bazel
vendored
Normal file
10
third_party/sobol_data/BUILD.bazel
vendored
Normal file
@ -0,0 +1,10 @@
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
cc_library(
|
||||
name = "sobol_data",
|
||||
hdrs = ["sobol_data.h"],
|
||||
)
|
15
third_party/sobol_data/workspace.bzl
vendored
Normal file
15
third_party/sobol_data/workspace.bzl
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
"""Loads the sobol_data library, used by TF."""
|
||||
|
||||
load("//third_party:repo.bzl", "third_party_http_archive")
|
||||
|
||||
def repo():
|
||||
third_party_http_archive(
|
||||
name = "sobol_data",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/sobol_data/archive/835a7d7b1ee3bc83e575e302a985c66ec4b65249.tar.gz",
|
||||
"https://github.com/joe-kuo/sobol_data/archive/835a7d7b1ee3bc83e575e302a985c66ec4b65249.tar.gz",
|
||||
],
|
||||
sha256 = "583d7b975e506c076fc579d9139530596906b9195b203d42361417e9aad79b73",
|
||||
strip_prefix = "sobol_data-835a7d7b1ee3bc83e575e302a985c66ec4b65249",
|
||||
build_file = "//third_party/sobol_data:BUILD.bazel",
|
||||
)
|
Loading…
Reference in New Issue
Block a user