From 50a1c3be8b7251b9a93e34b658176887cc53890c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 19 Dec 2019 03:46:07 -0800 Subject: [PATCH] Add tf.math.sobol_sample for generating Sobol sequences (CPU implementation only). PiperOrigin-RevId: 286365254 Change-Id: Ia0c2482f4f264f36fe61db5f9c72f24db35faf65 --- tensorflow/core/BUILD | 4 +- .../base_api/api_def_SobolSample.pbtxt | 41 ++++ tensorflow/core/kernels/BUILD | 14 ++ tensorflow/core/kernels/sobol_op.cc | 182 ++++++++++++++++++ tensorflow/core/ops/math_ops.cc | 24 +++ tensorflow/core/ops/math_ops_test.cc | 11 ++ tensorflow/python/BUILD | 17 ++ tensorflow/python/ops/math_ops.py | 24 +++ tensorflow/python/ops/sobol_ops_test.py | 83 ++++++++ .../tools/api/golden/v1/tensorflow.math.pbtxt | 4 + .../api/golden/v1/tensorflow.raw_ops.pbtxt | 4 + .../tools/api/golden/v2/tensorflow.math.pbtxt | 4 + .../api/golden/v2/tensorflow.raw_ops.pbtxt | 4 + tensorflow/tools/lib_package/BUILD | 2 + tensorflow/tools/pip_package/BUILD | 2 + tensorflow/workspace.bzl | 2 + third_party/sobol_data/BUILD | 1 + third_party/sobol_data/BUILD.bazel | 10 + third_party/sobol_data/workspace.bzl | 15 ++ 19 files changed, 447 insertions(+), 1 deletion(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_SobolSample.pbtxt create mode 100644 tensorflow/core/kernels/sobol_op.cc create mode 100644 tensorflow/python/ops/sobol_ops_test.py create mode 100644 third_party/sobol_data/BUILD create mode 100644 third_party/sobol_data/BUILD.bazel create mode 100644 third_party/sobol_data/workspace.bzl diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 5e7e802db18..9a1b0a589bc 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1254,7 +1254,9 @@ cc_library( cc_library( name = "dynamic_kernels_impl", visibility = [":__subpackages__"], - deps = [], + deps = [ + "//tensorflow/core/kernels:sobol_op", + ], ) cc_library( diff --git a/tensorflow/core/api_def/base_api/api_def_SobolSample.pbtxt b/tensorflow/core/api_def/base_api/api_def_SobolSample.pbtxt new file mode 100644 index 00000000000..b80fff30fae --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_SobolSample.pbtxt @@ -0,0 +1,41 @@ +op { + graph_op_name: "SobolSample" + visibility: HIDDEN + in_arg { + name: "dim" + description: < +#include +#include +#include + +#include "third_party/eigen3/Eigen/Core" +#include "sobol_data.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/platform_strings.h" + +namespace tensorflow { + +// Embed the platform strings in this binary. +TF_PLATFORM_STRINGS() + +typedef Eigen::ThreadPoolDevice CPUDevice; + +namespace { + +// Each thread will calculate at least kMinBlockSize points in the sequence. +constexpr int kMinBlockSize = 512; + +// Returns number of digits in binary representation of n. +// Example: n=13. Binary representation is 1101. NumBinaryDigits(13) -> 4. +int NumBinaryDigits(int n) { + return static_cast(std::log2(n) + 1); +} + +// Returns position of rightmost zero digit in binary representation of n. +// Example: n=13. Binary representation is 1101. RightmostZeroBit(13) -> 1. +int RightmostZeroBit(int n) { + int k = 0; + while (n & 1) { + n >>= 1; + ++k; + } + return k; +} + +// Returns an integer representation of point `i` in the Sobol sequence of +// dimension `dim` using the given direction numbers. +Eigen::VectorXi GetFirstPoint(int i, int dim, + const Eigen::MatrixXi& direction_numbers) { + // Index variables used in this function, consistent with notation in [1]. + // i - point in the Sobol sequence + // j - dimension + // k - binary digit + Eigen::VectorXi integer_sequence = Eigen::VectorXi::Zero(dim); + // go/wiki/Sobol_sequence#A_fast_algorithm_for_the_construction_of_Sobol_sequences + int gray_code = i ^ (i >> 1); + int num_digits = NumBinaryDigits(i); + for (int j = 0; j < dim; ++j) { + for (int k = 0; k < num_digits; ++k) { + if ((gray_code >> k) & 1) integer_sequence(j) ^= direction_numbers(j, k); + } + } + return integer_sequence; +} + +// Calculates `num_results` Sobol points of dimension `dim` starting at the +// point `start_point + skip` and writes them into `output` starting at point +// `start_point`. +template +void CalculateSobolSample(int32_t dim, int32_t num_results, int32_t skip, + int32_t start_point, + typename TTypes::Flat output) { + // Index variables used in this function, consistent with notation in [1]. + // i - point in the Sobol sequence + // j - dimension + // k - binary digit + const int num_digits = + NumBinaryDigits(skip + start_point + num_results + 1); + Eigen::MatrixXi direction_numbers(dim, num_digits); + + // Shift things so we can use integers everywhere. Before we write to output, + // divide by constant to convert back to floats. + const T normalizing_constant = 1./(1 << num_digits); + for (int j = 0; j < dim; ++j) { + for (int k = 0; k < num_digits; ++k) { + direction_numbers(j, k) = sobol_data::kDirectionNumbers[j][k] + << (num_digits - k - 1); + } + } + + // If needed, skip ahead to the appropriate point in the sequence. Otherwise + // we start with the first column of direction numbers. + Eigen::VectorXi integer_sequence = + (skip + start_point > 0) + ? GetFirstPoint(skip + start_point + 1, dim, direction_numbers) + : direction_numbers.col(0); + + for (int j = 0; j < dim; ++j) { + output(start_point * dim + j) = integer_sequence(j) * normalizing_constant; + } + // go/wiki/Sobol_sequence#A_fast_algorithm_for_the_construction_of_Sobol_sequences + for (int i = start_point + 1; i < num_results + start_point; ++i) { + // The Gray code for the current point differs from the preceding one by + // just a single bit -- the rightmost bit. + int k = RightmostZeroBit(i + skip); + // Update the current point from the preceding one with a single XOR + // operation per dimension. + for (int j = 0; j < dim; ++j) { + integer_sequence(j) ^= direction_numbers(j, k); + output(i * dim + j) = integer_sequence(j) * normalizing_constant; + } + } +} + +} // namespace + +template +class SobolSampleOp : public OpKernel { + public: + explicit SobolSampleOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + int32_t dim = context->input(0).scalar()(); + int32_t num_results = context->input(1).scalar()(); + int32_t skip = context->input(2).scalar()(); + + OP_REQUIRES(context, dim >= 1, + errors::InvalidArgument("dim must be at least one")); + OP_REQUIRES(context, dim <= sobol_data::kMaxSobolDim, + errors::InvalidArgument("dim must be at most ", + sobol_data::kMaxSobolDim)); + OP_REQUIRES(context, num_results >= 1, + errors::InvalidArgument("num_results must be at least one")); + OP_REQUIRES(context, skip >= 0, + errors::InvalidArgument("skip must be non-negative")); + OP_REQUIRES(context, + num_results < std::numeric_limits::max() - skip, + errors::InvalidArgument("num_results+skip must be less than ", + std::numeric_limits::max())); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + 0, TensorShape({num_results, dim}), &output)); + auto output_flat = output->flat(); + const DeviceBase::CpuWorkerThreads& worker_threads = + *(context->device()->tensorflow_cpu_worker_threads()); + int num_threads = worker_threads.num_threads; + int block_size = std::max( + kMinBlockSize, static_cast(std::ceil( + static_cast(num_results) / num_threads))); + worker_threads.workers->TransformRangeConcurrently( + block_size, num_results /* total */, + [&dim, &skip, &output_flat](const int start, const int end) { + CalculateSobolSample(dim, end - start /* num_results */, skip, + start, output_flat); + }); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("SobolSample").Device(DEVICE_CPU).TypeConstraint("dtype"), + SobolSampleOp); +REGISTER_KERNEL_BUILDER( + Name("SobolSample").Device(DEVICE_CPU).TypeConstraint("dtype"), + SobolSampleOp); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index ccdcf0b76e6..d8be0b265c4 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1907,4 +1907,28 @@ REGISTER_OP("NextAfter") .Output("output: T") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); +REGISTER_OP("SobolSample") + .Input("dim: int32") + .Input("num_results: int32") + .Input("skip: int32") + .Attr("dtype: {float, double} = DT_DOUBLE") + .Output("samples: dtype") + .SetShapeFn([](shape_inference::InferenceContext* c) { + ShapeHandle unused; + // inputs must be scalars + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + const Tensor* dim_t = c->input_tensor(0); + const Tensor* num_results_t = c->input_tensor(1); + if (dim_t == nullptr || num_results_t == nullptr) { + c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); + return Status::OK(); + } + const int32 output_size = + dim_t->scalar()() * num_results_t->scalar()(); + c->set_output(0, c->Vector(output_size)); + return Status::OK(); + }); + } // namespace tensorflow diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index 7ebd7889a35..7c8989f8c9b 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -593,4 +593,15 @@ TEST(MathOpsTest, Bincount_ShapeFn) { INFER_OK(op, "[?];[];?", "[?]"); INFER_OK(op, "[?];[];[?]", "[?]"); } + +TEST(MathOpsTest, SobolSample) { + ShapeInferenceTestOp op("SobolSample"); + + // All inputs should be scalar. + INFER_ERROR("must be rank 0", op, "[1];?;?"); + INFER_ERROR("must be rank 0", op, "?;[1];?"); + INFER_ERROR("must be rank 0", op, "?;?;[1]"); + + INFER_OK(op, "[];[];[]", "[?]"); +} } // end namespace tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index a424ba25ceb..47e989341e0 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4828,6 +4828,23 @@ py_test( ], ) +cuda_py_test( + name = "sobol_ops_test", + size = "small", + srcs = ["ops/sobol_ops_test.py"], + additional_deps = [ + ":framework_for_generated_wrappers", + ":framework_test_lib", + ":math_ops", + ":platform_test", + "//third_party/py/numpy", + ], + kernels = [ + "//tensorflow/core/kernels:libtfkernel_sobol_op.so", + ], + tags = ["no_windows_gpu"], +) + cuda_py_test( name = "special_math_ops_test", size = "medium", diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 20151e7228b..2db32c691c1 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -4450,3 +4450,27 @@ def exp(x, name=None): # pylint: enable=g-docstring-has-escape + + +@tf_export("math.sobol_sample") +def sobol_sample(dim, num_results, skip=0, dtype=None, name=None): + """Generates points from the Sobol sequence. + + Creates a Sobol sequence with `num_results` samples. Each sample has dimension + `dim`. Skips the first `skip` samples. + + Args: + dim: Positive scalar `Tensor` representing each sample's dimension. + num_results: Positive scalar `Tensor` of dtype int32. The number of Sobol + points to return in the output. + skip: (Optional) Positive scalar `Tensor` of dtype int32. The number of + initial points of the Sobol sequence to skip. Default value is 0. + dtype: (Optional) The dtype of the sample. One of: `float32` or `float64`. + Default value is determined by the C++ kernel. + name: (Optional) Python `str` name prefixed to ops created by this function. + + Returns: + `Tensor` of samples from Sobol sequence with `shape` [num_results, dim]. + """ + with ops.name_scope(name, "sobol", [dim, num_results, skip]): + return gen_math_ops.sobol_sample(dim, num_results, skip, dtype=dtype) diff --git a/tensorflow/python/ops/sobol_ops_test.py b/tensorflow/python/ops/sobol_ops_test.py new file mode 100644 index 00000000000..3a9e52ad47d --- /dev/null +++ b/tensorflow/python/ops/sobol_ops_test.py @@ -0,0 +1,83 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests Sobol sequence generator.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +class SobolSampleOpTest(test_util.TensorFlowTestCase): + + def test_basic(self): + for dtype in [np.float64, np.float32]: + expected = np.array([[.5, .5], [.75, .25], [.25, .75], [.375, .375]]) + sample = self.evaluate(math_ops.sobol_sample(2, 4, dtype=dtype)) + self.assertAllClose(expected, sample, 0.001) + + def test_more_known_values(self): + for dtype in [np.float64, np.float32]: + sample = math_ops.sobol_sample(5, 31, dtype=dtype) + expected = [[0.50, 0.50, 0.50, 0.50, 0.50], + [0.75, 0.25, 0.25, 0.25, 0.75], + [0.25, 0.75, 0.75, 0.75, 0.25], + [0.375, 0.375, 0.625, 0.875, 0.375], + [0.875, 0.875, 0.125, 0.375, 0.875], + [0.625, 0.125, 0.875, 0.625, 0.625], + [0.125, 0.625, 0.375, 0.125, 0.125], + [0.1875, 0.3125, 0.9375, 0.4375, 0.5625], + [0.6875, 0.8125, 0.4375, 0.9375, 0.0625], + [0.9375, 0.0625, 0.6875, 0.1875, 0.3125], + [0.4375, 0.5625, 0.1875, 0.6875, 0.8125], + [0.3125, 0.1875, 0.3125, 0.5625, 0.9375], + [0.8125, 0.6875, 0.8125, 0.0625, 0.4375], + [0.5625, 0.4375, 0.0625, 0.8125, 0.1875], + [0.0625, 0.9375, 0.5625, 0.3125, 0.6875], + [0.09375, 0.46875, 0.46875, 0.65625, 0.28125], + [0.59375, 0.96875, 0.96875, 0.15625, 0.78125], + [0.84375, 0.21875, 0.21875, 0.90625, 0.53125], + [0.34375, 0.71875, 0.71875, 0.40625, 0.03125], + [0.46875, 0.09375, 0.84375, 0.28125, 0.15625], + [0.96875, 0.59375, 0.34375, 0.78125, 0.65625], + [0.71875, 0.34375, 0.59375, 0.03125, 0.90625], + [0.21875, 0.84375, 0.09375, 0.53125, 0.40625], + [0.15625, 0.15625, 0.53125, 0.84375, 0.84375], + [0.65625, 0.65625, 0.03125, 0.34375, 0.34375], + [0.90625, 0.40625, 0.78125, 0.59375, 0.09375], + [0.40625, 0.90625, 0.28125, 0.09375, 0.59375], + [0.28125, 0.28125, 0.15625, 0.21875, 0.71875], + [0.78125, 0.78125, 0.65625, 0.71875, 0.21875], + [0.53125, 0.03125, 0.40625, 0.46875, 0.46875], + [0.03125, 0.53125, 0.90625, 0.96875, 0.96875]] + self.assertAllClose(expected, self.evaluate(sample), .001) + + def test_skip(self): + dim = 10 + n = 50 + skip = 17 + for dtype in [np.float64, np.float32]: + sample_noskip = math_ops.sobol_sample(dim, n + skip, dtype=dtype) + sample_skip = math_ops.sobol_sample(dim, n, skip, dtype=dtype) + + self.assertAllClose( + self.evaluate(sample_noskip)[skip:, :], self.evaluate(sample_skip)) + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt index c904681f633..c24b1c38179 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt @@ -420,6 +420,10 @@ tf_module { name: "sinh" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "sobol_sample" + argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], " + } member_method { name: "softmax" argspec: "args=[\'logits\', \'axis\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 98837bf3b63..2441232462d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -3836,6 +3836,10 @@ tf_module { name: "SnapshotDataset" argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'shard_size_bytes\', \'pending_snapshot_expiry_seconds\', \'num_reader_threads\', \'reader_buffer_size\', \'num_writer_threads\', \'writer_buffer_size\', \'shuffle_on_read\', \'seed\', \'seed2\', \'mode\', \'snapshot_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'10737418240\', \'86400\', \'1\', \'1\', \'1\', \'1\', \'False\', \'0\', \'0\', \'auto\', \'\', \'None\'], " } + member_method { + name: "SobolSample" + argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } member_method { name: "Softmax" argspec: "args=[\'logits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt index 2ec2ab27476..33828112832 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt @@ -420,6 +420,10 @@ tf_module { name: "sinh" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "sobol_sample" + argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], " + } member_method { name: "softmax" argspec: "args=[\'logits\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 98837bf3b63..2441232462d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -3836,6 +3836,10 @@ tf_module { name: "SnapshotDataset" argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'shard_size_bytes\', \'pending_snapshot_expiry_seconds\', \'num_reader_threads\', \'reader_buffer_size\', \'num_writer_threads\', \'writer_buffer_size\', \'shuffle_on_read\', \'seed\', \'seed2\', \'mode\', \'snapshot_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'10737418240\', \'86400\', \'1\', \'1\', \'1\', \'1\', \'False\', \'0\', \'0\', \'auto\', \'\', \'None\'], " } + member_method { + name: "SobolSample" + argspec: "args=[\'dim\', \'num_results\', \'skip\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " + } member_method { name: "Softmax" argspec: "args=[\'logits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index d110cd114b5..b8bb7914e84 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -161,6 +161,7 @@ genrule( "@png//:LICENSE", "@com_google_protobuf//:LICENSE", "@snappy//:COPYING", + "@sobol_data//:LICENSE", "@zlib_archive//:zlib.h", "@six_archive//:LICENSE", ] + select({ @@ -232,6 +233,7 @@ genrule( "@png//:LICENSE", "@com_google_protobuf//:LICENSE", "@snappy//:COPYING", + "@sobol_data//:LICENSE", "@zlib_archive//:zlib.h", "@grpc//:LICENSE", "@grpc//third_party/address_sorting:LICENSE", diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 891866aaa07..f32dce02faf 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -47,6 +47,7 @@ py_binary( # Add dynamic kernel dso files here. DYNAMIC_LOADED_KERNELS = [ + "//tensorflow/core/kernels:libtfkernel_sobol_op.so", ] COMMON_PIP_DEPS = [ @@ -160,6 +161,7 @@ filegroup( "@com_google_protobuf//:LICENSE", "@six_archive//:LICENSE", "@snappy//:COPYING", + "@sobol_data//:LICENSE", "@swig//:LICENSE", "@termcolor_archive//:COPYING.txt", "@zlib_archive//:zlib.h", diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index de5c50588ca..268844e24ab 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -38,6 +38,7 @@ load("//third_party/kissfft:workspace.bzl", kissfft = "repo") load("//third_party/pasta:workspace.bzl", pasta = "repo") load("//third_party/psimd:workspace.bzl", psimd = "repo") load("//third_party/pthreadpool:workspace.bzl", pthreadpool = "repo") +load("//third_party/sobol_data:workspace.bzl", sobol_data = "repo") def initialize_third_party(): """ Load third party repositories. See above load() statements. """ @@ -58,6 +59,7 @@ def initialize_third_party(): pasta() psimd() pthreadpool() + sobol_data() # Sanitize a dependency so that it works correctly from code that includes # TensorFlow as a submodule. diff --git a/third_party/sobol_data/BUILD b/third_party/sobol_data/BUILD new file mode 100644 index 00000000000..82bab3ffd96 --- /dev/null +++ b/third_party/sobol_data/BUILD @@ -0,0 +1 @@ +# This empty BUILD file is required to make Bazel treat this directory as a package. diff --git a/third_party/sobol_data/BUILD.bazel b/third_party/sobol_data/BUILD.bazel new file mode 100644 index 00000000000..09b51dfbe39 --- /dev/null +++ b/third_party/sobol_data/BUILD.bazel @@ -0,0 +1,10 @@ +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +exports_files(["LICENSE"]) + +cc_library( + name = "sobol_data", + hdrs = ["sobol_data.h"], +) diff --git a/third_party/sobol_data/workspace.bzl b/third_party/sobol_data/workspace.bzl new file mode 100644 index 00000000000..71840df7b9b --- /dev/null +++ b/third_party/sobol_data/workspace.bzl @@ -0,0 +1,15 @@ +"""Loads the sobol_data library, used by TF.""" + +load("//third_party:repo.bzl", "third_party_http_archive") + +def repo(): + third_party_http_archive( + name = "sobol_data", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/sobol_data/archive/835a7d7b1ee3bc83e575e302a985c66ec4b65249.tar.gz", + "https://github.com/joe-kuo/sobol_data/archive/835a7d7b1ee3bc83e575e302a985c66ec4b65249.tar.gz", + ], + sha256 = "583d7b975e506c076fc579d9139530596906b9195b203d42361417e9aad79b73", + strip_prefix = "sobol_data-835a7d7b1ee3bc83e575e302a985c66ec4b65249", + build_file = "//third_party/sobol_data:BUILD.bazel", + )