Skip to content

Commit

Permalink
Support configurable LMUL in VectorXoshiro.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 620864942
  • Loading branch information
tgale96 authored and copybara-github committed Apr 2, 2024
1 parent 4ce48ca commit 4155c08
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 21 deletions.
29 changes: 16 additions & 13 deletions hwy/contrib/random/random-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,18 +170,21 @@ class Xoshiro {

} // namespace internal

template <int kPow2 = 1>
class VectorXoshiro {
private:
using VU64 = Vec<ScalableTag<std::uint64_t>>;
using TagU64 = ScalableTag<std::uint64_t, kPow2>;
using TagF64 = ScalableTag<double, kPow2>;

using VU64 = Vec<TagU64>;
using StateType = AlignedNDArray<std::uint64_t, 2>;
#if HWY_HAVE_FLOAT64
using VF64 = Vec<ScalableTag<double>>;
using VF64 = Vec<TagF64>;
#endif
public:
explicit VectorXoshiro(const std::uint64_t seed,
const std::uint64_t threadNumber = 0)
: state_{{internal::Xoshiro::StateSize(),
Lanes(ScalableTag<std::uint64_t>{})}},
: state_{{internal::Xoshiro::StateSize(), Lanes(TagU64{})}},
streams{state_.shape().back()} {
internal::Xoshiro xoshiro{seed};

Expand All @@ -202,7 +205,7 @@ class VectorXoshiro {

AlignedVector<std::uint64_t> operator()(const std::size_t n) {
AlignedVector<std::uint64_t> result(n);
const ScalableTag<std::uint64_t> tag{};
const TagU64 tag{};
auto s0 = Load(tag, state_[{0}].data());
auto s1 = Load(tag, state_[{1}].data());
auto s2 = Load(tag, state_[{2}].data());
Expand All @@ -221,7 +224,7 @@ class VectorXoshiro {
template <std::uint64_t N>
std::array<std::uint64_t, N> operator()() noexcept {
alignas(HWY_ALIGNMENT) std::array<std::uint64_t, N> result;
const ScalableTag<std::uint64_t> tag{};
const TagU64 tag{};
auto s0 = Load(tag, state_[{0}].data());
auto s1 = Load(tag, state_[{1}].data());
auto s2 = Load(tag, state_[{2}].data());
Expand All @@ -246,7 +249,7 @@ class VectorXoshiro {
#if HWY_HAVE_FLOAT64

HWY_INLINE VF64 Uniform() noexcept {
const ScalableTag<double> real_tag{};
const TagF64 real_tag{};
const auto MUL_VALUE = Set(real_tag, internal::kMulConst);
const auto bits = ShiftRight<11>(Next());
const auto real = ConvertTo(real_tag, bits);
Expand All @@ -255,8 +258,8 @@ class VectorXoshiro {

AlignedVector<double> Uniform(const std::size_t n) {
AlignedVector<double> result(n);
const ScalableTag<std::uint64_t> tag{};
const ScalableTag<double> real_tag{};
const TagU64 tag{};
const TagF64 real_tag{};
const auto MUL_VALUE = Set(real_tag, internal::kMulConst);

auto s0 = Load(tag, state_[{0}].data());
Expand All @@ -282,8 +285,8 @@ class VectorXoshiro {
template <std::uint64_t N>
std::array<double, N> Uniform() noexcept {
alignas(HWY_ALIGNMENT) std::array<double, N> result;
const ScalableTag<std::uint64_t> tag{};
const ScalableTag<double> real_tag{};
const TagU64 tag{};
const TagF64 real_tag{};
const auto MUL_VALUE = Set(real_tag, internal::kMulConst);

auto s0 = Load(tag, state_[{0}].data());
Expand Down Expand Up @@ -326,7 +329,7 @@ class VectorXoshiro {
}

HWY_INLINE VU64 Next() noexcept {
const ScalableTag<std::uint64_t> tag{};
const TagU64 tag{};
auto s0 = Load(tag, state_[{0}].data());
auto s1 = Load(tag, state_[{1}].data());
auto s2 = Load(tag, state_[{2}].data());
Expand Down Expand Up @@ -368,7 +371,7 @@ class CachedXoshiro {
}

private:
VectorXoshiro generator_;
VectorXoshiro</*kPow2=*/1> generator_;
alignas(HWY_ALIGNMENT) std::array<result_type, size> cache_;
std::size_t index_;

Expand Down
16 changes: 8 additions & 8 deletions hwy/contrib/random/random_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ std::uint64_t GetSeed() { return static_cast<uint64_t>(std::time(nullptr)); }
void RngLoop(const std::uint64_t seed, std::uint64_t* HWY_RESTRICT result,
const size_t size) {
const ScalableTag<std::uint64_t> d;
VectorXoshiro generator{seed};
VectorXoshiro<> generator{seed};
for (size_t i = 0; i < size; i += Lanes(d)) {
Store(generator(), d, result + i);
}
Expand All @@ -40,7 +40,7 @@ void RngLoop(const std::uint64_t seed, std::uint64_t* HWY_RESTRICT result,
void UniformLoop(const std::uint64_t seed, double* HWY_RESTRICT result,
const size_t size) {
const ScalableTag<double> d;
VectorXoshiro generator{seed};
VectorXoshiro<> generator{seed};
for (size_t i = 0; i < size; i += Lanes(d)) {
Store(generator.Uniform(), d, result + i);
}
Expand All @@ -49,7 +49,7 @@ void UniformLoop(const std::uint64_t seed, double* HWY_RESTRICT result,

void TestSeeding() {
const std::uint64_t seed = GetSeed();
VectorXoshiro generator{seed};
VectorXoshiro<> generator{seed};
internal::Xoshiro reference{seed};
const auto& state = generator.GetState();
const ScalableTag<std::uint64_t> d;
Expand All @@ -72,7 +72,7 @@ void TestSeeding() {
void TestMultiThreadSeeding() {
const std::uint64_t seed = GetSeed();
const std::uint64_t threadId = std::random_device()() % 1000;
VectorXoshiro generator{seed, threadId};
VectorXoshiro<> generator{seed, threadId};
internal::Xoshiro reference{seed};

for (std::size_t i = 0UL; i < threadId; ++i) {
Expand Down Expand Up @@ -146,7 +146,7 @@ void TestUniformDist() {

void TestNextNRandomUint64() {
const std::uint64_t seed = GetSeed();
VectorXoshiro generator{seed};
VectorXoshiro<> generator{seed};
const auto result_array = generator.operator()(tests);
std::vector<internal::Xoshiro> reference;
reference.emplace_back(seed);
Expand Down Expand Up @@ -174,7 +174,7 @@ void TestNextNRandomUint64() {

void TestNextFixedNRandomUint64() {
const std::uint64_t seed = GetSeed();
VectorXoshiro generator{seed};
VectorXoshiro<> generator{seed};
const auto result_array = generator.operator()<tests>();
std::vector<internal::Xoshiro> reference;
reference.emplace_back(seed);
Expand Down Expand Up @@ -203,7 +203,7 @@ void TestNextFixedNRandomUint64() {
#if HWY_HAVE_FLOAT64
void TestNextNUniformDist() {
const std::uint64_t seed = GetSeed();
VectorXoshiro generator{seed};
VectorXoshiro<> generator{seed};
const auto result_array = generator.Uniform(tests);
internal::Xoshiro reference{seed};
const ScalableTag<double> d;
Expand All @@ -222,7 +222,7 @@ void TestNextNUniformDist() {

void TestNextFixedNUniformDist() {
const std::uint64_t seed = GetSeed();
VectorXoshiro generator{seed};
VectorXoshiro<> generator{seed};
const auto result_array = generator.Uniform<tests>();
internal::Xoshiro reference{seed};
const ScalableTag<double> d;
Expand Down

0 comments on commit 4155c08

Please sign in to comment.