diff --git a/hwy/contrib/random/random-inl.h b/hwy/contrib/random/random-inl.h index 4779ef5a47..3f75603959 100644 --- a/hwy/contrib/random/random-inl.h +++ b/hwy/contrib/random/random-inl.h @@ -170,18 +170,21 @@ class Xoshiro { } // namespace internal +template class VectorXoshiro { private: - using VU64 = Vec>; + using TagU64 = ScalableTag; + using TagF64 = ScalableTag; + + using VU64 = Vec; using StateType = AlignedNDArray; #if HWY_HAVE_FLOAT64 - using VF64 = Vec>; + using VF64 = Vec; #endif public: explicit VectorXoshiro(const std::uint64_t seed, const std::uint64_t threadNumber = 0) - : state_{{internal::Xoshiro::StateSize(), - Lanes(ScalableTag{})}}, + : state_{{internal::Xoshiro::StateSize(), Lanes(TagU64{})}}, streams{state_.shape().back()} { internal::Xoshiro xoshiro{seed}; @@ -202,7 +205,7 @@ class VectorXoshiro { AlignedVector operator()(const std::size_t n) { AlignedVector result(n); - const ScalableTag tag{}; + const TagU64 tag{}; auto s0 = Load(tag, state_[{0}].data()); auto s1 = Load(tag, state_[{1}].data()); auto s2 = Load(tag, state_[{2}].data()); @@ -221,7 +224,7 @@ class VectorXoshiro { template std::array operator()() noexcept { alignas(HWY_ALIGNMENT) std::array result; - const ScalableTag tag{}; + const TagU64 tag{}; auto s0 = Load(tag, state_[{0}].data()); auto s1 = Load(tag, state_[{1}].data()); auto s2 = Load(tag, state_[{2}].data()); @@ -246,7 +249,7 @@ class VectorXoshiro { #if HWY_HAVE_FLOAT64 HWY_INLINE VF64 Uniform() noexcept { - const ScalableTag real_tag{}; + const TagF64 real_tag{}; const auto MUL_VALUE = Set(real_tag, internal::kMulConst); const auto bits = ShiftRight<11>(Next()); const auto real = ConvertTo(real_tag, bits); @@ -255,8 +258,8 @@ class VectorXoshiro { AlignedVector Uniform(const std::size_t n) { AlignedVector result(n); - const ScalableTag tag{}; - const ScalableTag real_tag{}; + const TagU64 tag{}; + const TagF64 real_tag{}; const auto MUL_VALUE = Set(real_tag, internal::kMulConst); auto s0 = Load(tag, state_[{0}].data()); @@ -282,8 +285,8 @@ class VectorXoshiro { template std::array Uniform() noexcept { alignas(HWY_ALIGNMENT) std::array result; - const ScalableTag tag{}; - const ScalableTag real_tag{}; + const TagU64 tag{}; + const TagF64 real_tag{}; const auto MUL_VALUE = Set(real_tag, internal::kMulConst); auto s0 = Load(tag, state_[{0}].data()); @@ -326,7 +329,7 @@ class VectorXoshiro { } HWY_INLINE VU64 Next() noexcept { - const ScalableTag tag{}; + const TagU64 tag{}; auto s0 = Load(tag, state_[{0}].data()); auto s1 = Load(tag, state_[{1}].data()); auto s2 = Load(tag, state_[{2}].data()); @@ -368,7 +371,7 @@ class CachedXoshiro { } private: - VectorXoshiro generator_; + VectorXoshiro<> generator_; alignas(HWY_ALIGNMENT) std::array cache_; std::size_t index_; diff --git a/hwy/contrib/random/random_test.cc b/hwy/contrib/random/random_test.cc index 44cdd30000..a5849bbaf6 100644 --- a/hwy/contrib/random/random_test.cc +++ b/hwy/contrib/random/random_test.cc @@ -30,7 +30,7 @@ std::uint64_t GetSeed() { return static_cast(std::time(nullptr)); } void RngLoop(const std::uint64_t seed, std::uint64_t* HWY_RESTRICT result, const size_t size) { const ScalableTag d; - VectorXoshiro generator{seed}; + VectorXoshiro<> generator{seed}; for (size_t i = 0; i < size; i += Lanes(d)) { Store(generator(), d, result + i); } @@ -40,7 +40,7 @@ void RngLoop(const std::uint64_t seed, std::uint64_t* HWY_RESTRICT result, void UniformLoop(const std::uint64_t seed, double* HWY_RESTRICT result, const size_t size) { const ScalableTag d; - VectorXoshiro generator{seed}; + VectorXoshiro<> generator{seed}; for (size_t i = 0; i < size; i += Lanes(d)) { Store(generator.Uniform(), d, result + i); } @@ -49,7 +49,7 @@ void UniformLoop(const std::uint64_t seed, double* HWY_RESTRICT result, void TestSeeding() { const std::uint64_t seed = GetSeed(); - VectorXoshiro generator{seed}; + VectorXoshiro<> generator{seed}; internal::Xoshiro reference{seed}; const auto& state = generator.GetState(); const ScalableTag d; @@ -72,7 +72,7 @@ void TestSeeding() { void TestMultiThreadSeeding() { const std::uint64_t seed = GetSeed(); const std::uint64_t threadId = std::random_device()() % 1000; - VectorXoshiro generator{seed, threadId}; + VectorXoshiro<> generator{seed, threadId}; internal::Xoshiro reference{seed}; for (std::size_t i = 0UL; i < threadId; ++i) { @@ -146,7 +146,7 @@ void TestUniformDist() { void TestNextNRandomUint64() { const std::uint64_t seed = GetSeed(); - VectorXoshiro generator{seed}; + VectorXoshiro<> generator{seed}; const auto result_array = generator.operator()(tests); std::vector reference; reference.emplace_back(seed); @@ -174,7 +174,7 @@ void TestNextNRandomUint64() { void TestNextFixedNRandomUint64() { const std::uint64_t seed = GetSeed(); - VectorXoshiro generator{seed}; + VectorXoshiro<> generator{seed}; const auto result_array = generator.operator()(); std::vector reference; reference.emplace_back(seed); @@ -203,7 +203,7 @@ void TestNextFixedNRandomUint64() { #if HWY_HAVE_FLOAT64 void TestNextNUniformDist() { const std::uint64_t seed = GetSeed(); - VectorXoshiro generator{seed}; + VectorXoshiro<> generator{seed}; const auto result_array = generator.Uniform(tests); internal::Xoshiro reference{seed}; const ScalableTag d; @@ -222,7 +222,7 @@ void TestNextNUniformDist() { void TestNextFixedNUniformDist() { const std::uint64_t seed = GetSeed(); - VectorXoshiro generator{seed}; + VectorXoshiro<> generator{seed}; const auto result_array = generator.Uniform(); internal::Xoshiro reference{seed}; const ScalableTag d;