diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 095daf70498..4e2866865a2 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -522,6 +522,7 @@ cc_library( cc_library( name = "array", + srcs = ["array.cc"], hdrs = ["array.h"], deps = [ ":status", diff --git a/tensorflow/compiler/xla/array.cc b/tensorflow/compiler/xla/array.cc new file mode 100644 index 00000000000..bd958fc0d18 --- /dev/null +++ b/tensorflow/compiler/xla/array.cc @@ -0,0 +1,32 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/array.h" + +namespace xla { + +// Specialization of FillRandom() method for complex64 type. Uses real part of +// the stddev parameter as the standard deviation value. +template <> +void Array::FillRandom(const complex64& stddev, const double mean, + const int seed) { + std::mt19937 g(seed); + std::normal_distribution distribution(mean, std::real(stddev)); + for (int64 i = 0; i < num_elements(); ++i) { + values_[i] = complex64(distribution(g), distribution(g)); + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 2ab45c75049..67bad0f8af7 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -575,6 +575,12 @@ class Array { std::unique_ptr values_; }; +// Specialization of FillRandom() method for complex64 type. Uses real part of +// the stddev parameter as the standard deviation value. +template <> +void Array::FillRandom(const complex64& stddev, const double mean, + const int seed); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_ARRAY_H_ diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 1863c78e7e1..281401dfbde 100755 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1103,8 +1103,7 @@ class HloConvolutionInstruction : public HloInstruction { void set_feature_group_count(int64 num_feature_groups) { feature_group_count_ = num_feature_groups; } - // The number of feature groups. Must be a divisor of the input batch - // dimension. + // The number of batch groups. Must be a divisor of the input batch dimension. int64 batch_group_count() const { return batch_group_count_; } void set_batch_group_count(int64 num_batch_groups) { batch_group_count_ = num_batch_groups; @@ -1138,8 +1137,7 @@ class HloConvolutionInstruction : public HloInstruction { // The number of feature groups. Must be a divisor of the input feature // dimension and output feature dimension. int64 feature_group_count_; - // The number of feature groups. Must be a divisor of the input batch - // dimension. + // The number of batch groups. Must be a divisor of the input batch dimension. int64 batch_group_count_; // Describes the window used for a convolution. Window window_;