ESPResSo
Extensible Simulation Package for Research on Soft Matter Systems
electrokinetics/generated_kernels/philox_rand.h
/*
Copyright 2010-2011, D. E. Shaw Research. All rights reserved.
Copyright 2019-2024, Michael Kuron.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

* Redistributions of source code must retain the above copyright
  notice, this list of conditions, and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright
  notice, this list of conditions, and the following disclaimer in the
  documentation and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

// kernel generated with pystencils v1.4, lbmpy v1.4+1.g3fc1c8f.dirty, sympy
// v1.12.1, lbmpy_walberla/pystencils_walberla from waLBerla commit
// 007e77e077ad9d22b5eed6f3d3118240993e553c
/**
 * @file
 * Philox counter-based RNG from @cite salmon11a.
 * Adapted from the pystencils source file
 * https://i10git.cs.fau.de/pycodegen/pystencils/-/blob/b4d7ef7cb5b499f3fa55ebfcd598ac7d6e11a3db/src/pystencils/include/philox_rand.h
 */
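// Illustrative usage sketch (not part of the upstream file): each call maps a
// 128-bit counter (ctr0..ctr3) and a 64-bit key (key0, key1) to uniform
// variates, so callers typically encode quantities such as a time step and a
// cell index in the counter words. The names `timestep`, `cell_index`, and
// `seed` below are hypothetical placeholders.
//
// @code
// float r1, r2, r3, r4;
// philox_float4(timestep, cell_index, 0u, 0u, seed, 0u, r1, r2, r3, r4);
// // r1..r4 are uniform floats in (0, 1), reproducible for the same
// // counter/key combination.
// @endcode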

#pragma once

#if !defined(__OPENCL_VERSION__) && !defined(__HIPCC_RTC__)
#if defined(__SSE2__) || (defined(_MSC_VER) && !defined(_M_ARM64))
#include <emmintrin.h> // SSE2
#endif
#ifdef __AVX2__
#include <immintrin.h> // AVX*
#elif defined(__SSE4_1__) || (defined(_MSC_VER) && !defined(_M_ARM64))
#include <smmintrin.h> // SSE4
#ifdef __FMA__
#include <immintrin.h> // FMA
#endif
#endif

#if defined(_MSC_VER) && defined(_M_ARM64)
#define __ARM_NEON
#endif

#ifdef __ARM_NEON
#include <arm_neon.h>
#endif
#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_SME)
#include <arm_sve.h>
#endif

#if defined(__powerpc__) && defined(__GNUC__) && !defined(__clang__) &&        \
    !defined(__xlC__)
#include <ppu_intrinsics.h>
#endif
#ifdef __ALTIVEC__
#include <altivec.h>
#undef bool
#ifndef _ARCH_PWR8
#include <pveclib/vec_int64_ppc.h>
#endif
#endif

#ifdef __riscv_v
#include <riscv_vector.h>
#endif
#endif

#if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_SVE)
#define SVE_QUALIFIERS __arm_streaming_compatible
#elif defined(__ARM_FEATURE_SME)
#define SVE_QUALIFIERS __arm_streaming
#else
#define SVE_QUALIFIERS
#endif

#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) ||               \
    defined(__clang__) && defined(__CUDA__)
#define QUALIFIERS static __forceinline__ __device__
#elif defined(__OPENCL_VERSION__)
#define QUALIFIERS static inline
#else
#define QUALIFIERS inline
#include "myintrin.h"
#endif

#define PHILOX_W32_0 (0x9E3779B9)
#define PHILOX_W32_1 (0xBB67AE85)
#define PHILOX_M4x32_0 (0xD2511F53)
#define PHILOX_M4x32_1 (0xCD9E8D57)
#define TWOPOW53_INV_DOUBLE (1.1102230246251565e-16)
#define TWOPOW32_INV_FLOAT (2.3283064e-10f)
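// PHILOX_W32_0/1 are the Weyl-sequence key increments and PHILOX_M4x32_0/1
// the round multipliers from the Philox4x32 specification;
// TWOPOW53_INV_DOUBLE is 2^-53 and TWOPOW32_INV_FLOAT is 2^-32 (rounded to
// float precision), used below to map integer outputs onto the unit interval.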

#ifdef __OPENCL_VERSION__
#include "opencl_stdint.h"
typedef uint32_t uint32;
typedef uint64_t uint64;
#else
#ifndef __HIPCC_RTC__
#include <cstdint>
#endif
typedef std::uint32_t uint32;
typedef std::uint64_t uint64;
#endif

#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) &&           \
    __ARM_FEATURE_SVE_BITS > 0
typedef svfloat32_t svfloat32_st
    __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
typedef svfloat64_t svfloat64_st
    __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
#elif defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_SME)
typedef svfloat32_t svfloat32_st;
typedef svfloat64_t svfloat64_st;
#endif

QUALIFIERS uint32 mulhilo32(uint32 a, uint32 b, uint32 *hip) {
#if !defined(__CUDA_ARCH__) && !defined(__HIP_DEVICE_COMPILE__) &&             \
    (!defined(__clang__) || !defined(__CUDA__))
  // host code
#if defined(__powerpc__) && (!defined(__clang__) || defined(__xlC__))
  *hip = __mulhwu(a, b);
  return a * b;
#elif defined(__OPENCL_VERSION__)
  *hip = mul_hi(a, b);
  return a * b;
#else
  uint64 product = ((uint64)a) * ((uint64)b);
  *hip = product >> 32;
  return (uint32)product;
#endif
#else
  // device code
  *hip = __umulhi(a, b);
  return a * b;
#endif
}
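// mulhilo32 returns the low 32 bits of the 64-bit product a*b and stores the
// high 32 bits through hip. For example, with a = PHILOX_M4x32_0 = 0xD2511F53
// and b = 2, the full product is 0x1A4A23EA6, so the function returns
// 0xA4A23EA6 and sets *hip = 0x1.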

QUALIFIERS void _philox4x32round(uint32 *ctr, uint32 *key) {
  uint32 hi0;
  uint32 hi1;
  uint32 lo0 = mulhilo32(PHILOX_M4x32_0, ctr[0], &hi0);
  uint32 lo1 = mulhilo32(PHILOX_M4x32_1, ctr[2], &hi1);

  ctr[0] = hi1 ^ ctr[1] ^ key[0];
  ctr[1] = lo1;
  ctr[2] = hi0 ^ ctr[3] ^ key[1];
  ctr[3] = lo0;
}

QUALIFIERS void _philox4x32bumpkey(uint32 *key) {
  key[0] += PHILOX_W32_0;
  key[1] += PHILOX_W32_1;
}

QUALIFIERS double _uniform_double_hq(uint32 x, uint32 y) {
  uint64 z = (uint64)x ^ ((uint64)y << (53 - 32));
  return z * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE / 2.0);
}
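// _uniform_double_hq splices two 32-bit words into a 53-bit integer
// z = x ^ (y << 21): x supplies bits 0..31 and the shifted y bits 21..52.
// Scaling by 2^-53 and recentering by 2^-54 then yields a uniform double in
// [2^-54, 1 - 2^-54], i.e. strictly inside (0, 1).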

QUALIFIERS void philox_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2,
                               uint32 ctr3, uint32 key0, uint32 key1,
#ifdef __OPENCL_VERSION__
                               double *rnd1, double *rnd2)
#else
                               double &rnd1, double &rnd2)
#endif
{
  uint32 key[2] = {key0, key1};
  uint32 ctr[4] = {ctr0, ctr1, ctr2, ctr3};
  _philox4x32round(ctr, key); // 1
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 2
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 3
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 4
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 5
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 6
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 7
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 8
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 9
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 10

#ifdef __OPENCL_VERSION__
  *rnd1 = _uniform_double_hq(ctr[0], ctr[1]);
  *rnd2 = _uniform_double_hq(ctr[2], ctr[3]);
#else
  rnd1 = _uniform_double_hq(ctr[0], ctr[1]);
  rnd2 = _uniform_double_hq(ctr[2], ctr[3]);
#endif
}

QUALIFIERS void philox_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2,
                              uint32 ctr3, uint32 key0, uint32 key1,
#ifdef __OPENCL_VERSION__
                              float *rnd1, float *rnd2, float *rnd3,
                              float *rnd4)
#else
                              float &rnd1, float &rnd2, float &rnd3,
                              float &rnd4)
#endif
{
  uint32 key[2] = {key0, key1};
  uint32 ctr[4] = {ctr0, ctr1, ctr2, ctr3};
  _philox4x32round(ctr, key); // 1
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 2
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 3
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 4
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 5
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 6
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 7
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 8
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 9
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 10

#ifdef __OPENCL_VERSION__
  *rnd1 = ctr[0] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT / 2.0f);
  *rnd2 = ctr[1] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT / 2.0f);
  *rnd3 = ctr[2] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT / 2.0f);
  *rnd4 = ctr[3] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT / 2.0f);
#else
  rnd1 = ctr[0] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT / 2.0f);
  rnd2 = ctr[1] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT / 2.0f);
  rnd3 = ctr[2] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT / 2.0f);
  rnd4 = ctr[3] * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT / 2.0f);
#endif
}
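// The affine map rnd = ctr * 2^-32 + 2^-33 sends the uint32 range
// [0, 2^32 - 1] to [2^-33, 1 - 2^-33], so neither 0.0f nor 1.0f is ever
// produced (up to float rounding of the converted counter).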

#if !defined(__CUDA_ARCH__) && !defined(__OPENCL_VERSION__) &&                 \
    !defined(__HIP_DEVICE_COMPILE__) &&                                        \
    (!defined(__clang__) || !defined(__CUDA__))
#if defined(__SSE4_1__) || (defined(_MSC_VER) && !defined(_M_ARM64))
QUALIFIERS void _philox4x32round(__m128i *ctr, __m128i *key) {
  __m128i lohi0a = _mm_mul_epu32(ctr[0], _mm_set1_epi32(PHILOX_M4x32_0));
  __m128i lohi0b =
      _mm_mul_epu32(_mm_srli_epi64(ctr[0], 32), _mm_set1_epi32(PHILOX_M4x32_0));
  __m128i lohi1a = _mm_mul_epu32(ctr[2], _mm_set1_epi32(PHILOX_M4x32_1));
  __m128i lohi1b =
      _mm_mul_epu32(_mm_srli_epi64(ctr[2], 32), _mm_set1_epi32(PHILOX_M4x32_1));

  lohi0a = _mm_shuffle_epi32(lohi0a, 0xD8);
  lohi0b = _mm_shuffle_epi32(lohi0b, 0xD8);
  lohi1a = _mm_shuffle_epi32(lohi1a, 0xD8);
  lohi1b = _mm_shuffle_epi32(lohi1b, 0xD8);

  __m128i lo0 = _mm_unpacklo_epi32(lohi0a, lohi0b);
  __m128i hi0 = _mm_unpackhi_epi32(lohi0a, lohi0b);
  __m128i lo1 = _mm_unpacklo_epi32(lohi1a, lohi1b);
  __m128i hi1 = _mm_unpackhi_epi32(lohi1a, lohi1b);

  ctr[0] = _mm_xor_si128(_mm_xor_si128(hi1, ctr[1]), key[0]);
  ctr[1] = lo1;
  ctr[2] = _mm_xor_si128(_mm_xor_si128(hi0, ctr[3]), key[1]);
  ctr[3] = lo0;
}

QUALIFIERS void _philox4x32bumpkey(__m128i *key) {
  key[0] = _mm_add_epi32(key[0], _mm_set1_epi32(PHILOX_W32_0));
  key[1] = _mm_add_epi32(key[1], _mm_set1_epi32(PHILOX_W32_1));
}

template <bool high>
QUALIFIERS __m128d _uniform_double_hq(__m128i x, __m128i y) {
  // convert 32 to 64 bit
  if (high) {
    x = _mm_unpackhi_epi32(x, _mm_setzero_si128());
    y = _mm_unpackhi_epi32(y, _mm_setzero_si128());
  } else {
    x = _mm_unpacklo_epi32(x, _mm_setzero_si128());
    y = _mm_unpacklo_epi32(y, _mm_setzero_si128());
  }

  // calculate z = x ^ (y << (53 - 32))
  __m128i z = _mm_sll_epi64(y, _mm_set1_epi64x(53 - 32));
  z = _mm_xor_si128(x, z);

  // convert uint64 to double
  __m128d rs = _my_cvtepu64_pd(z);
  // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
#ifdef __FMA__
  rs = _mm_fmadd_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE),
                    _mm_set1_pd(TWOPOW53_INV_DOUBLE / 2.0));
#else
  rs = _mm_mul_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE));
  rs = _mm_add_pd(rs, _mm_set1_pd(TWOPOW53_INV_DOUBLE / 2.0));
#endif

  return rs;
}

QUALIFIERS void philox_float4(__m128i ctr0, __m128i ctr1, __m128i ctr2,
                              __m128i ctr3, uint32 key0, uint32 key1,
                              __m128 &rnd1, __m128 &rnd2, __m128 &rnd3,
                              __m128 &rnd4) {
  __m128i key[2] = {_mm_set1_epi32(key0), _mm_set1_epi32(key1)};
  __m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
  _philox4x32round(ctr, key); // 1
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 2
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 3
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 4
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 5
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 6
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 7
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 8
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 9
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 10

  // convert uint32 to float
  rnd1 = _my_cvtepu32_ps(ctr[0]);
  rnd2 = _my_cvtepu32_ps(ctr[1]);
  rnd3 = _my_cvtepu32_ps(ctr[2]);
  rnd4 = _my_cvtepu32_ps(ctr[3]);
  // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
#ifdef __FMA__
  rnd1 = _mm_fmadd_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT),
                      _mm_set1_ps(TWOPOW32_INV_FLOAT / 2.0));
  rnd2 = _mm_fmadd_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT),
                      _mm_set1_ps(TWOPOW32_INV_FLOAT / 2.0));
  rnd3 = _mm_fmadd_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT),
                      _mm_set1_ps(TWOPOW32_INV_FLOAT / 2.0));
  rnd4 = _mm_fmadd_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT),
                      _mm_set1_ps(TWOPOW32_INV_FLOAT / 2.0));
#else
  rnd1 = _mm_mul_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT));
  rnd1 = _mm_add_ps(rnd1, _mm_set1_ps(TWOPOW32_INV_FLOAT / 2.0f));
  rnd2 = _mm_mul_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT));
  rnd2 = _mm_add_ps(rnd2, _mm_set1_ps(TWOPOW32_INV_FLOAT / 2.0f));
  rnd3 = _mm_mul_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT));
  rnd3 = _mm_add_ps(rnd3, _mm_set1_ps(TWOPOW32_INV_FLOAT / 2.0f));
  rnd4 = _mm_mul_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT));
  rnd4 = _mm_add_ps(rnd4, _mm_set1_ps(TWOPOW32_INV_FLOAT / 2.0f));
#endif
}

QUALIFIERS void philox_double2(__m128i ctr0, __m128i ctr1, __m128i ctr2,
                               __m128i ctr3, uint32 key0, uint32 key1,
                               __m128d &rnd1lo, __m128d &rnd1hi,
                               __m128d &rnd2lo, __m128d &rnd2hi) {
  __m128i key[2] = {_mm_set1_epi32(key0), _mm_set1_epi32(key1)};
  __m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
  _philox4x32round(ctr, key); // 1
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 2
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 3
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 4
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 5
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 6
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 7
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 8
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 9
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 10

  rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
  rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
  rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
  rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}

QUALIFIERS void philox_float4(uint32 ctr0, __m128i ctr1, uint32 ctr2,
                              uint32 ctr3, uint32 key0, uint32 key1,
                              __m128 &rnd1, __m128 &rnd2, __m128 &rnd3,
                              __m128 &rnd4) {
  __m128i ctr0v = _mm_set1_epi32(ctr0);
  __m128i ctr2v = _mm_set1_epi32(ctr2);
  __m128i ctr3v = _mm_set1_epi32(ctr3);

  philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}

QUALIFIERS void philox_double2(uint32 ctr0, __m128i ctr1, uint32 ctr2,
                               uint32 ctr3, uint32 key0, uint32 key1,
                               __m128d &rnd1lo, __m128d &rnd1hi,
                               __m128d &rnd2lo, __m128d &rnd2hi) {
  __m128i ctr0v = _mm_set1_epi32(ctr0);
  __m128i ctr2v = _mm_set1_epi32(ctr2);
  __m128i ctr3v = _mm_set1_epi32(ctr3);

  philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo,
                 rnd2hi);
}

QUALIFIERS void philox_double2(uint32 ctr0, __m128i ctr1, uint32 ctr2,
                               uint32 ctr3, uint32 key0, uint32 key1,
                               __m128d &rnd1, __m128d &rnd2) {
  __m128i ctr0v = _mm_set1_epi32(ctr0);
  __m128i ctr2v = _mm_set1_epi32(ctr2);
  __m128i ctr3v = _mm_set1_epi32(ctr3);

  __m128d ignore;
  philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2,
                 ignore);
}
#endif

#ifdef __ALTIVEC__
QUALIFIERS void _philox4x32round(__vector unsigned int *ctr,
                                 __vector unsigned int *key) {
#ifndef _ARCH_PWR8
  __vector unsigned int lo0 = vec_mul(ctr[0], vec_splats(PHILOX_M4x32_0));
  __vector unsigned int hi0 = vec_mulhuw(ctr[0], vec_splats(PHILOX_M4x32_0));
  __vector unsigned int lo1 = vec_mul(ctr[2], vec_splats(PHILOX_M4x32_1));
  __vector unsigned int hi1 = vec_mulhuw(ctr[2], vec_splats(PHILOX_M4x32_1));
#elif defined(_ARCH_PWR10)
  __vector unsigned int lo0 = vec_mul(ctr[0], vec_splats(PHILOX_M4x32_0));
  __vector unsigned int hi0 = vec_mulh(ctr[0], vec_splats(PHILOX_M4x32_0));
  __vector unsigned int lo1 = vec_mul(ctr[2], vec_splats(PHILOX_M4x32_1));
  __vector unsigned int hi1 = vec_mulh(ctr[2], vec_splats(PHILOX_M4x32_1));
#else
  __vector unsigned int lohi0a =
      (__vector unsigned int)vec_mule(ctr[0], vec_splats(PHILOX_M4x32_0));
  __vector unsigned int lohi0b =
      (__vector unsigned int)vec_mulo(ctr[0], vec_splats(PHILOX_M4x32_0));
  __vector unsigned int lohi1a =
      (__vector unsigned int)vec_mule(ctr[2], vec_splats(PHILOX_M4x32_1));
  __vector unsigned int lohi1b =
      (__vector unsigned int)vec_mulo(ctr[2], vec_splats(PHILOX_M4x32_1));

#ifdef __LITTLE_ENDIAN__
  __vector unsigned int lo0 = vec_mergee(lohi0a, lohi0b);
  __vector unsigned int lo1 = vec_mergee(lohi1a, lohi1b);
  __vector unsigned int hi0 = vec_mergeo(lohi0a, lohi0b);
  __vector unsigned int hi1 = vec_mergeo(lohi1a, lohi1b);
#else
  __vector unsigned int lo0 = vec_mergeo(lohi0a, lohi0b);
  __vector unsigned int lo1 = vec_mergeo(lohi1a, lohi1b);
  __vector unsigned int hi0 = vec_mergee(lohi0a, lohi0b);
  __vector unsigned int hi1 = vec_mergee(lohi1a, lohi1b);
#endif
#endif

  ctr[0] = vec_xor(vec_xor(hi1, ctr[1]), key[0]);
  ctr[1] = lo1;
  ctr[2] = vec_xor(vec_xor(hi0, ctr[3]), key[1]);
  ctr[3] = lo0;
}

QUALIFIERS void _philox4x32bumpkey(__vector unsigned int *key) {
  key[0] = vec_add(key[0], vec_splats(PHILOX_W32_0));
  key[1] = vec_add(key[1], vec_splats(PHILOX_W32_1));
}

#ifdef __VSX__
template <bool high>
QUALIFIERS __vector double _uniform_double_hq(__vector unsigned int x,
                                              __vector unsigned int y) {
  // convert 32 to 64 bit
#ifdef __LITTLE_ENDIAN__
  if (high) {
    x = vec_mergel(x, vec_splats(0U));
    y = vec_mergel(y, vec_splats(0U));
  } else {
    x = vec_mergeh(x, vec_splats(0U));
    y = vec_mergeh(y, vec_splats(0U));
  }
#else
  if (high) {
    x = vec_mergel(vec_splats(0U), x);
    y = vec_mergel(vec_splats(0U), y);
  } else {
    x = vec_mergeh(vec_splats(0U), x);
    y = vec_mergeh(vec_splats(0U), y);
  }
#endif

  // calculate z = x ^ (y << (53 - 32))
#ifdef _ARCH_PWR8
  __vector unsigned long long z =
      vec_sl((__vector unsigned long long)y, vec_splats(53ULL - 32ULL));
#else
  __vector unsigned long long z =
      vec_vsld((__vector unsigned long long)y, vec_splats(53ULL - 32ULL));
#endif
  z = vec_xor((__vector unsigned long long)x, z);

  // convert uint64 to double
#ifdef __xlC__
  __vector double rs = vec_ctd(z, 0);
#else
  __vector double rs = vec_ctf(z, 0);
#endif
  // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
  rs = vec_madd(rs, vec_splats(TWOPOW53_INV_DOUBLE),
                vec_splats(TWOPOW53_INV_DOUBLE / 2.0));

  return rs;
}
#endif

QUALIFIERS void philox_float4(__vector unsigned int ctr0,
                              __vector unsigned int ctr1,
                              __vector unsigned int ctr2,
                              __vector unsigned int ctr3, uint32 key0,
                              uint32 key1, __vector float &rnd1,
                              __vector float &rnd2, __vector float &rnd3,
                              __vector float &rnd4) {
  __vector unsigned int key[2] = {vec_splats(key0), vec_splats(key1)};
  __vector unsigned int ctr[4] = {ctr0, ctr1, ctr2, ctr3};
  _philox4x32round(ctr, key); // 1
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 2
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 3
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 4
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 5
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 6
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 7
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 8
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 9
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 10

  // convert uint32 to float
  rnd1 = vec_ctf(ctr[0], 0);
  rnd2 = vec_ctf(ctr[1], 0);
  rnd3 = vec_ctf(ctr[2], 0);
  rnd4 = vec_ctf(ctr[3], 0);
  // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
  rnd1 = vec_madd(rnd1, vec_splats(TWOPOW32_INV_FLOAT),
                  vec_splats(TWOPOW32_INV_FLOAT / 2.0f));
  rnd2 = vec_madd(rnd2, vec_splats(TWOPOW32_INV_FLOAT),
                  vec_splats(TWOPOW32_INV_FLOAT / 2.0f));
  rnd3 = vec_madd(rnd3, vec_splats(TWOPOW32_INV_FLOAT),
                  vec_splats(TWOPOW32_INV_FLOAT / 2.0f));
  rnd4 = vec_madd(rnd4, vec_splats(TWOPOW32_INV_FLOAT),
                  vec_splats(TWOPOW32_INV_FLOAT / 2.0f));
}

#ifdef __VSX__
QUALIFIERS void philox_double2(__vector unsigned int ctr0,
                               __vector unsigned int ctr1,
                               __vector unsigned int ctr2,
                               __vector unsigned int ctr3, uint32 key0,
                               uint32 key1, __vector double &rnd1lo,
                               __vector double &rnd1hi, __vector double &rnd2lo,
                               __vector double &rnd2hi) {
  __vector unsigned int key[2] = {vec_splats(key0), vec_splats(key1)};
  __vector unsigned int ctr[4] = {ctr0, ctr1, ctr2, ctr3};
  _philox4x32round(ctr, key); // 1
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 2
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 3
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 4
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 5
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 6
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 7
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 8
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 9
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 10

  rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
  rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
  rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
  rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
#endif

QUALIFIERS void philox_float4(uint32 ctr0, __vector unsigned int ctr1,
                              uint32 ctr2, uint32 ctr3, uint32 key0,
                              uint32 key1, __vector float &rnd1,
                              __vector float &rnd2, __vector float &rnd3,
                              __vector float &rnd4) {
  __vector unsigned int ctr0v = vec_splats(ctr0);
  __vector unsigned int ctr2v = vec_splats(ctr2);
  __vector unsigned int ctr3v = vec_splats(ctr3);

  philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}

QUALIFIERS void philox_float4(uint32 ctr0, __vector int ctr1, uint32 ctr2,
                              uint32 ctr3, uint32 key0, uint32 key1,
                              __vector float &rnd1, __vector float &rnd2,
                              __vector float &rnd3, __vector float &rnd4) {
  philox_float4(ctr0, (__vector unsigned int)ctr1, ctr2, ctr3, key0, key1, rnd1,
                rnd2, rnd3, rnd4);
}

#ifdef __VSX__
QUALIFIERS void philox_double2(uint32 ctr0, __vector unsigned int ctr1,
                               uint32 ctr2, uint32 ctr3, uint32 key0,
                               uint32 key1, __vector double &rnd1lo,
                               __vector double &rnd1hi, __vector double &rnd2lo,
                               __vector double &rnd2hi) {
  __vector unsigned int ctr0v = vec_splats(ctr0);
  __vector unsigned int ctr2v = vec_splats(ctr2);
  __vector unsigned int ctr3v = vec_splats(ctr3);

  philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo,
                 rnd2hi);
}

QUALIFIERS void philox_double2(uint32 ctr0, __vector unsigned int ctr1,
                               uint32 ctr2, uint32 ctr3, uint32 key0,
                               uint32 key1, __vector double &rnd1,
                               __vector double &rnd2) {
  __vector unsigned int ctr0v = vec_splats(ctr0);
  __vector unsigned int ctr2v = vec_splats(ctr2);
  __vector unsigned int ctr3v = vec_splats(ctr3);

  __vector double ignore;
  philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2,
                 ignore);
}

QUALIFIERS void philox_double2(uint32 ctr0, __vector int ctr1, uint32 ctr2,
                               uint32 ctr3, uint32 key0, uint32 key1,
                               __vector double &rnd1, __vector double &rnd2) {
  philox_double2(ctr0, (__vector unsigned int)ctr1, ctr2, ctr3, key0, key1,
                 rnd1, rnd2);
}
#endif
#endif

#if defined(__ARM_NEON)
QUALIFIERS void _philox4x32round(uint32x4_t *ctr, uint32x4_t *key) {
  uint32x4_t lohi0a = vreinterpretq_u32_u64(
      vmull_u32(vget_low_u32(ctr[0]), vdup_n_u32(PHILOX_M4x32_0)));
  uint32x4_t lohi0b = vreinterpretq_u32_u64(
      vmull_high_u32(ctr[0], vdupq_n_u32(PHILOX_M4x32_0)));
  uint32x4_t lohi1a = vreinterpretq_u32_u64(
      vmull_u32(vget_low_u32(ctr[2]), vdup_n_u32(PHILOX_M4x32_1)));
  uint32x4_t lohi1b = vreinterpretq_u32_u64(
      vmull_high_u32(ctr[2], vdupq_n_u32(PHILOX_M4x32_1)));

  uint32x4_t lo0 = vuzp1q_u32(lohi0a, lohi0b);
  uint32x4_t lo1 = vuzp1q_u32(lohi1a, lohi1b);
  uint32x4_t hi0 = vuzp2q_u32(lohi0a, lohi0b);
  uint32x4_t hi1 = vuzp2q_u32(lohi1a, lohi1b);

  ctr[0] = veorq_u32(veorq_u32(hi1, ctr[1]), key[0]);
  ctr[1] = lo1;
  ctr[2] = veorq_u32(veorq_u32(hi0, ctr[3]), key[1]);
  ctr[3] = lo0;
}
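// In the NEON round above, vmull_u32/vmull_high_u32 produce full 64-bit
// products of the counter lanes with the Philox multipliers; reinterpreted as
// 32-bit lanes (little-endian), the even elements hold the low halves and the
// odd elements the high halves, which vuzp1q_u32/vuzp2q_u32 de-interleave
// into the lo*/hi* vectors.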

QUALIFIERS void _philox4x32bumpkey(uint32x4_t *key) {
  key[0] = vaddq_u32(key[0], vdupq_n_u32(PHILOX_W32_0));
  key[1] = vaddq_u32(key[1], vdupq_n_u32(PHILOX_W32_1));
}

template <bool high>
QUALIFIERS float64x2_t _uniform_double_hq(uint32x4_t x, uint32x4_t y) {
  // convert 32 to 64 bit
  if (high) {
    x = vzip2q_u32(x, vdupq_n_u32(0));
    y = vzip2q_u32(y, vdupq_n_u32(0));
  } else {
    x = vzip1q_u32(x, vdupq_n_u32(0));
    y = vzip1q_u32(y, vdupq_n_u32(0));
  }

  // calculate z = x ^ (y << (53 - 32))
  uint64x2_t z = vshlq_n_u64(vreinterpretq_u64_u32(y), 53 - 32);
  z = veorq_u64(vreinterpretq_u64_u32(x), z);

  // convert uint64 to double
  float64x2_t rs = vcvtq_f64_u64(z);
  // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
  rs = vfmaq_f64(vdupq_n_f64(TWOPOW53_INV_DOUBLE / 2.0),
                 vdupq_n_f64(TWOPOW53_INV_DOUBLE), rs);

  return rs;
}

QUALIFIERS void philox_float4(uint32x4_t ctr0, uint32x4_t ctr1, uint32x4_t ctr2,
                              uint32x4_t ctr3, uint32 key0, uint32 key1,
                              float32x4_t &rnd1, float32x4_t &rnd2,
                              float32x4_t &rnd3, float32x4_t &rnd4) {
  uint32x4_t key[2] = {vdupq_n_u32(key0), vdupq_n_u32(key1)};
  uint32x4_t ctr[4] = {ctr0, ctr1, ctr2, ctr3};
  _philox4x32round(ctr, key); // 1
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 2
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 3
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 4
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 5
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 6
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 7
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 8
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 9
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 10

  // convert uint32 to float
  rnd1 = vcvtq_f32_u32(ctr[0]);
  rnd2 = vcvtq_f32_u32(ctr[1]);
  rnd3 = vcvtq_f32_u32(ctr[2]);
  rnd4 = vcvtq_f32_u32(ctr[3]);
  // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
  rnd1 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT / 2.0),
                   vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd1);
  rnd2 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT / 2.0),
                   vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd2);
  rnd3 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT / 2.0),
                   vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd3);
  rnd4 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT / 2.0),
                   vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd4);
}

QUALIFIERS void philox_double2(uint32x4_t ctr0, uint32x4_t ctr1,
                               uint32x4_t ctr2, uint32x4_t ctr3, uint32 key0,
                               uint32 key1, float64x2_t &rnd1lo,
                               float64x2_t &rnd1hi, float64x2_t &rnd2lo,
                               float64x2_t &rnd2hi) {
  uint32x4_t key[2] = {vdupq_n_u32(key0), vdupq_n_u32(key1)};
  uint32x4_t ctr[4] = {ctr0, ctr1, ctr2, ctr3};
  _philox4x32round(ctr, key); // 1
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 2
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 3
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 4
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 5
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 6
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 7
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 8
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 9
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 10

  rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
  rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
  rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
  rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}

QUALIFIERS void philox_float4(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2,
                              uint32 ctr3, uint32 key0, uint32 key1,
                              float32x4_t &rnd1, float32x4_t &rnd2,
                              float32x4_t &rnd3, float32x4_t &rnd4) {
  uint32x4_t ctr0v = vdupq_n_u32(ctr0);
  uint32x4_t ctr2v = vdupq_n_u32(ctr2);
  uint32x4_t ctr3v = vdupq_n_u32(ctr3);

  philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}

#ifndef _MSC_VER
QUALIFIERS void philox_float4(uint32 ctr0, int32x4_t ctr1, uint32 ctr2,
                              uint32 ctr3, uint32 key0, uint32 key1,
                              float32x4_t &rnd1, float32x4_t &rnd2,
                              float32x4_t &rnd3, float32x4_t &rnd4) {
  philox_float4(ctr0, vreinterpretq_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1,
                rnd2, rnd3, rnd4);
}
#endif

QUALIFIERS void philox_double2(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2,
                               uint32 ctr3, uint32 key0, uint32 key1,
                               float64x2_t &rnd1lo, float64x2_t &rnd1hi,
                               float64x2_t &rnd2lo, float64x2_t &rnd2hi) {
  uint32x4_t ctr0v = vdupq_n_u32(ctr0);
  uint32x4_t ctr2v = vdupq_n_u32(ctr2);
  uint32x4_t ctr3v = vdupq_n_u32(ctr3);

  philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo,
                 rnd2hi);
}

QUALIFIERS void philox_double2(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2,
                               uint32 ctr3, uint32 key0, uint32 key1,
                               float64x2_t &rnd1, float64x2_t &rnd2) {
  uint32x4_t ctr0v = vdupq_n_u32(ctr0);
  uint32x4_t ctr2v = vdupq_n_u32(ctr2);
  uint32x4_t ctr3v = vdupq_n_u32(ctr3);

  float64x2_t ignore;
  philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2,
                 ignore);
}

#ifndef _MSC_VER
QUALIFIERS void philox_double2(uint32 ctr0, int32x4_t ctr1, uint32 ctr2,
                               uint32 ctr3, uint32 key0, uint32 key1,
                               float64x2_t &rnd1, float64x2_t &rnd2) {
  philox_double2(ctr0, vreinterpretq_u32_s32(ctr1), ctr2, ctr3, key0, key1,
                 rnd1, rnd2);
}
#endif
#endif

#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_SME)
QUALIFIERS void _philox4x32round(svuint32x4_t &ctr,
                                 svuint32x2_t &key) SVE_QUALIFIERS {
  svuint32_t lo0 =
      svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 0), svdup_u32(PHILOX_M4x32_0));
  svuint32_t hi0 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 0),
                                svdup_u32(PHILOX_M4x32_0));
  svuint32_t lo1 =
      svmul_u32_x(svptrue_b32(), svget4_u32(ctr, 2), svdup_u32(PHILOX_M4x32_1));
  svuint32_t hi1 = svmulh_u32_x(svptrue_b32(), svget4_u32(ctr, 2),
                                svdup_u32(PHILOX_M4x32_1));

  ctr = svset4_u32(
      ctr, 0,
      sveor_u32_x(svptrue_b32(),
                  sveor_u32_x(svptrue_b32(), hi1, svget4_u32(ctr, 1)),
                  svget2_u32(key, 0)));
  ctr = svset4_u32(ctr, 1, lo1);
  ctr = svset4_u32(
      ctr, 2,
      sveor_u32_x(svptrue_b32(),
                  sveor_u32_x(svptrue_b32(), hi0, svget4_u32(ctr, 3)),
                  svget2_u32(key, 1)));
  ctr = svset4_u32(ctr, 3, lo0);
}

QUALIFIERS void _philox4x32bumpkey(svuint32x2_t &key) SVE_QUALIFIERS {
  key = svset2_u32(
      key, 0,
      svadd_u32_x(svptrue_b32(), svget2_u32(key, 0), svdup_u32(PHILOX_W32_0)));
  key = svset2_u32(
      key, 1,
      svadd_u32_x(svptrue_b32(), svget2_u32(key, 1), svdup_u32(PHILOX_W32_1)));
}

template <bool high>
QUALIFIERS svfloat64_t _uniform_double_hq(svuint32_t x,
                                          svuint32_t y) SVE_QUALIFIERS {
  // convert 32 to 64 bit
  if (high) {
    x = svzip2_u32(x, svdup_u32(0));
    y = svzip2_u32(y, svdup_u32(0));
  } else {
    x = svzip1_u32(x, svdup_u32(0));
    y = svzip1_u32(y, svdup_u32(0));
  }

  // calculate z = x ^ (y << (53 - 32))
  svuint64_t z =
      svlsl_n_u64_x(svptrue_b64(), svreinterpret_u64_u32(y), 53 - 32);
  z = sveor_u64_x(svptrue_b64(), svreinterpret_u64_u32(x), z);

  // convert uint64 to double
  svfloat64_t rs = svcvt_f64_u64_x(svptrue_b64(), z);
  // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
  rs = svmad_f64_x(svptrue_b64(), rs, svdup_f64(TWOPOW53_INV_DOUBLE),
                   svdup_f64(TWOPOW53_INV_DOUBLE / 2.0));

  return rs;
}

QUALIFIERS void philox_float4(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2,
                              svuint32_t ctr3, uint32 key0, uint32 key1,
                              svfloat32_st &rnd1, svfloat32_st &rnd2,
                              svfloat32_st &rnd3,
                              svfloat32_st &rnd4) SVE_QUALIFIERS {
  svuint32x2_t key = svcreate2_u32(svdup_u32(key0), svdup_u32(key1));
  svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3);
  _philox4x32round(ctr, key); // 1
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 2
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 3
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 4
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 5
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 6
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 7
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 8
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 9
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 10

  // convert uint32 to float
  rnd1 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 0));
  rnd2 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 1));
  rnd3 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 2));
  rnd4 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 3));
  // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
  rnd1 = svmad_f32_x(svptrue_b32(), rnd1, svdup_f32(TWOPOW32_INV_FLOAT),
                     svdup_f32(TWOPOW32_INV_FLOAT / 2.0));
  rnd2 = svmad_f32_x(svptrue_b32(), rnd2, svdup_f32(TWOPOW32_INV_FLOAT),
                     svdup_f32(TWOPOW32_INV_FLOAT / 2.0));
  rnd3 = svmad_f32_x(svptrue_b32(), rnd3, svdup_f32(TWOPOW32_INV_FLOAT),
                     svdup_f32(TWOPOW32_INV_FLOAT / 2.0));
  rnd4 = svmad_f32_x(svptrue_b32(), rnd4, svdup_f32(TWOPOW32_INV_FLOAT),
                     svdup_f32(TWOPOW32_INV_FLOAT / 2.0));
}

QUALIFIERS void philox_double2(svuint32_t ctr0, svuint32_t ctr1,
                               svuint32_t ctr2, svuint32_t ctr3, uint32 key0,
                               uint32 key1, svfloat64_st &rnd1lo,
                               svfloat64_st &rnd1hi, svfloat64_st &rnd2lo,
                               svfloat64_st &rnd2hi) SVE_QUALIFIERS {
  svuint32x2_t key = svcreate2_u32(svdup_u32(key0), svdup_u32(key1));
  svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3);
  _philox4x32round(ctr, key); // 1
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 2
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 3
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 4
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 5
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 6
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 7
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 8
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 9
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 10

  rnd1lo = _uniform_double_hq<false>(svget4_u32(ctr, 0), svget4_u32(ctr, 1));
  rnd1hi = _uniform_double_hq<true>(svget4_u32(ctr, 0), svget4_u32(ctr, 1));
  rnd2lo = _uniform_double_hq<false>(svget4_u32(ctr, 2), svget4_u32(ctr, 3));
  rnd2hi = _uniform_double_hq<true>(svget4_u32(ctr, 2), svget4_u32(ctr, 3));
}

QUALIFIERS void philox_float4(uint32 ctr0, svuint32_t ctr1, uint32 ctr2,
                              uint32 ctr3, uint32 key0, uint32 key1,
                              svfloat32_st &rnd1, svfloat32_st &rnd2,
                              svfloat32_st &rnd3,
                              svfloat32_st &rnd4) SVE_QUALIFIERS {
  svuint32_t ctr0v = svdup_u32(ctr0);
  svuint32_t ctr2v = svdup_u32(ctr2);
  svuint32_t ctr3v = svdup_u32(ctr3);

  philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}

QUALIFIERS void philox_float4(uint32 ctr0, svint32_t ctr1, uint32 ctr2,
                              uint32 ctr3, uint32 key0, uint32 key1,
                              svfloat32_st &rnd1, svfloat32_st &rnd2,
                              svfloat32_st &rnd3,
                              svfloat32_st &rnd4) SVE_QUALIFIERS {
  philox_float4(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, rnd1,
                rnd2, rnd3, rnd4);
}

QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2,
                               uint32 ctr3, uint32 key0, uint32 key1,
                               svfloat64_st &rnd1lo, svfloat64_st &rnd1hi,
                               svfloat64_st &rnd2lo,
                               svfloat64_st &rnd2hi) SVE_QUALIFIERS {
  svuint32_t ctr0v = svdup_u32(ctr0);
  svuint32_t ctr2v = svdup_u32(ctr2);
  svuint32_t ctr3v = svdup_u32(ctr3);

  philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo,
                 rnd2hi);
}

QUALIFIERS void philox_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2,
                               uint32 ctr3, uint32 key0, uint32 key1,
                               svfloat64_st &rnd1,
                               svfloat64_st &rnd2) SVE_QUALIFIERS {
  svuint32_t ctr0v = svdup_u32(ctr0);
  svuint32_t ctr2v = svdup_u32(ctr2);
  svuint32_t ctr3v = svdup_u32(ctr3);

  svfloat64_st ignore;
  philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2,
                 ignore);
}

QUALIFIERS void philox_double2(uint32 ctr0, svint32_t ctr1, uint32 ctr2,
                               uint32 ctr3, uint32 key0, uint32 key1,
                               svfloat64_st &rnd1,
                               svfloat64_st &rnd2) SVE_QUALIFIERS {
  philox_double2(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1,
                 rnd1, rnd2);
}
#endif

#if defined(__riscv_v)
QUALIFIERS void _philox4x32round(vuint32m1_t &ctr0, vuint32m1_t &ctr1,
                                 vuint32m1_t &ctr2, vuint32m1_t &ctr3,
                                 vuint32m1_t key0, vuint32m1_t key1) {
  vuint32m1_t lo0 = __riscv_vmul_vv_u32m1(
      ctr0, __riscv_vmv_v_x_u32m1(PHILOX_M4x32_0, __riscv_vsetvlmax_e32m1()),
      __riscv_vsetvlmax_e32m1());
  vuint32m1_t hi0 = __riscv_vmulhu_vv_u32m1(
      ctr0, __riscv_vmv_v_x_u32m1(PHILOX_M4x32_0, __riscv_vsetvlmax_e32m1()),
      __riscv_vsetvlmax_e32m1());
  vuint32m1_t lo1 = __riscv_vmul_vv_u32m1(
      ctr2, __riscv_vmv_v_x_u32m1(PHILOX_M4x32_1, __riscv_vsetvlmax_e32m1()),
      __riscv_vsetvlmax_e32m1());
  vuint32m1_t hi1 = __riscv_vmulhu_vv_u32m1(
      ctr2, __riscv_vmv_v_x_u32m1(PHILOX_M4x32_1, __riscv_vsetvlmax_e32m1()),
      __riscv_vsetvlmax_e32m1());

  ctr0 = __riscv_vxor_vv_u32m1(
      __riscv_vxor_vv_u32m1(hi1, ctr1, __riscv_vsetvlmax_e32m1()), key0,
      __riscv_vsetvlmax_e32m1());
  ctr1 = lo1;
  ctr2 = __riscv_vxor_vv_u32m1(
      __riscv_vxor_vv_u32m1(hi0, ctr3, __riscv_vsetvlmax_e32m1()), key1,
      __riscv_vsetvlmax_e32m1());
  ctr3 = lo0;
}

QUALIFIERS void _philox4x32bumpkey(vuint32m1_t &key0, vuint32m1_t &key1) {
  key0 = __riscv_vadd_vv_u32m1(
      key0, __riscv_vmv_v_x_u32m1(PHILOX_W32_0, __riscv_vsetvlmax_e32m1()),
      __riscv_vsetvlmax_e32m1());
  key1 = __riscv_vadd_vv_u32m1(
      key1, __riscv_vmv_v_x_u32m1(PHILOX_W32_1, __riscv_vsetvlmax_e32m1()),
      __riscv_vsetvlmax_e32m1());
}

template <bool high>
QUALIFIERS vfloat64m1_t _uniform_double_hq(vuint32m1_t x, vuint32m1_t y) {
  // convert 32 to 64 bit
  if (high) {
    size_t s = __riscv_vsetvlmax_e32m1();
    x = __riscv_vslidedown_vx_u32m1(x, s / 2, s);
    y = __riscv_vslidedown_vx_u32m1(y, s / 2, s);
  }
  vuint64m1_t x64 = __riscv_vwcvtu_x_x_v_u64m1(
      __riscv_vlmul_trunc_v_u32m1_u32mf2(x), __riscv_vsetvlmax_e64m1());
  vuint64m1_t y64 = __riscv_vwcvtu_x_x_v_u64m1(
      __riscv_vlmul_trunc_v_u32m1_u32mf2(y), __riscv_vsetvlmax_e64m1());

  // calculate z = x ^ (y << (53 - 32))
  vuint64m1_t z =
      __riscv_vsll_vx_u64m1(y64, 53 - 32, __riscv_vsetvlmax_e64m1());
  z = __riscv_vxor_vv_u64m1(x64, z, __riscv_vsetvlmax_e64m1());

  // convert uint64 to double
  vfloat64m1_t rs = __riscv_vfcvt_f_xu_v_f64m1(z, __riscv_vsetvlmax_e64m1());
  // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
  rs = __riscv_vfmadd_vv_f64m1(
      rs,
      __riscv_vfmv_v_f_f64m1(TWOPOW53_INV_DOUBLE, __riscv_vsetvlmax_e64m1()),
      __riscv_vfmv_v_f_f64m1(TWOPOW53_INV_DOUBLE / 2.0,
                             __riscv_vsetvlmax_e64m1()),
      __riscv_vsetvlmax_e64m1());

  return rs;
}
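// A note on the RVV variant above: for the high half, vslidedown by half the
// vector length moves the upper 32-bit elements into the low positions;
// vlmul_trunc then takes that low half as a u32mf2 vector, which vwcvtu
// widens to 64-bit elements before the same shift/xor/convert sequence.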

QUALIFIERS void philox_float4(vuint32m1_t ctr0, vuint32m1_t ctr1,
                              vuint32m1_t ctr2, vuint32m1_t ctr3, uint32 key0,
                              uint32 key1, vfloat32m1_t &rnd1,
                              vfloat32m1_t &rnd2, vfloat32m1_t &rnd3,
                              vfloat32m1_t &rnd4) {
  vuint32m1_t key0v = __riscv_vmv_v_x_u32m1(key0, __riscv_vsetvlmax_e32m1());
  vuint32m1_t key1v = __riscv_vmv_v_x_u32m1(key1, __riscv_vsetvlmax_e32m1());
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 1
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 2
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 3
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 4
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 5
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 6
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 7
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 8
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 9
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 10

  // convert uint32 to float
  rnd1 = __riscv_vfcvt_f_xu_v_f32m1(ctr0, __riscv_vsetvlmax_e32m1());
  rnd2 = __riscv_vfcvt_f_xu_v_f32m1(ctr1, __riscv_vsetvlmax_e32m1());
  rnd3 = __riscv_vfcvt_f_xu_v_f32m1(ctr2, __riscv_vsetvlmax_e32m1());
  rnd4 = __riscv_vfcvt_f_xu_v_f32m1(ctr3, __riscv_vsetvlmax_e32m1());
  // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
  rnd1 = __riscv_vfmadd_vv_f32m1(
      rnd1,
      __riscv_vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, __riscv_vsetvlmax_e32m1()),
      __riscv_vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT / 2.0,
                             __riscv_vsetvlmax_e32m1()),
      __riscv_vsetvlmax_e32m1());
  rnd2 = __riscv_vfmadd_vv_f32m1(
      rnd2,
      __riscv_vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, __riscv_vsetvlmax_e32m1()),
      __riscv_vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT / 2.0,
                             __riscv_vsetvlmax_e32m1()),
      __riscv_vsetvlmax_e32m1());
  rnd3 = __riscv_vfmadd_vv_f32m1(
      rnd3,
      __riscv_vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, __riscv_vsetvlmax_e32m1()),
      __riscv_vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT / 2.0,
                             __riscv_vsetvlmax_e32m1()),
      __riscv_vsetvlmax_e32m1());
  rnd4 = __riscv_vfmadd_vv_f32m1(
      rnd4,
      __riscv_vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT, __riscv_vsetvlmax_e32m1()),
      __riscv_vfmv_v_f_f32m1(TWOPOW32_INV_FLOAT / 2.0,
                             __riscv_vsetvlmax_e32m1()),
      __riscv_vsetvlmax_e32m1());
}

QUALIFIERS void philox_double2(vuint32m1_t ctr0, vuint32m1_t ctr1,
                               vuint32m1_t ctr2, vuint32m1_t ctr3, uint32 key0,
                               uint32 key1, vfloat64m1_t &rnd1lo,
                               vfloat64m1_t &rnd1hi, vfloat64m1_t &rnd2lo,
                               vfloat64m1_t &rnd2hi) {
  vuint32m1_t key0v = __riscv_vmv_v_x_u32m1(key0, __riscv_vsetvlmax_e32m1());
  vuint32m1_t key1v = __riscv_vmv_v_x_u32m1(key1, __riscv_vsetvlmax_e32m1());
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 1
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 2
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 3
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 4
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 5
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 6
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 7
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 8
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 9
  _philox4x32bumpkey(key0v, key1v);
  _philox4x32round(ctr0, ctr1, ctr2, ctr3, key0v, key1v); // 10

  rnd1lo = _uniform_double_hq<false>(ctr0, ctr1);
  rnd1hi = _uniform_double_hq<true>(ctr0, ctr1);
  rnd2lo = _uniform_double_hq<false>(ctr2, ctr3);
  rnd2hi = _uniform_double_hq<true>(ctr2, ctr3);
}

QUALIFIERS void philox_float4(uint32 ctr0, vuint32m1_t ctr1, uint32 ctr2,
                              uint32 ctr3, uint32 key0, uint32 key1,
                              vfloat32m1_t &rnd1, vfloat32m1_t &rnd2,
                              vfloat32m1_t &rnd3, vfloat32m1_t &rnd4) {
  vuint32m1_t ctr0v = __riscv_vmv_v_x_u32m1(ctr0, __riscv_vsetvlmax_e32m1());
  vuint32m1_t ctr2v = __riscv_vmv_v_x_u32m1(ctr2, __riscv_vsetvlmax_e32m1());
  vuint32m1_t ctr3v = __riscv_vmv_v_x_u32m1(ctr3, __riscv_vsetvlmax_e32m1());

  philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}

QUALIFIERS void philox_float4(uint32 ctr0, vint32m1_t ctr1, uint32 ctr2,
                              uint32 ctr3, uint32 key0, uint32 key1,
                              vfloat32m1_t &rnd1, vfloat32m1_t &rnd2,
                              vfloat32m1_t &rnd3, vfloat32m1_t &rnd4) {
  philox_float4(ctr0, __riscv_vreinterpret_v_i32m1_u32m1(ctr1), ctr2, ctr3,
                key0, key1, rnd1, rnd2, rnd3, rnd4);
}

QUALIFIERS void philox_double2(uint32 ctr0, vuint32m1_t ctr1, uint32 ctr2,
                               uint32 ctr3, uint32 key0, uint32 key1,
                               vfloat64m1_t &rnd1lo, vfloat64m1_t &rnd1hi,
                               vfloat64m1_t &rnd2lo, vfloat64m1_t &rnd2hi) {
  vuint32m1_t ctr0v = __riscv_vmv_v_x_u32m1(ctr0, __riscv_vsetvlmax_e32m1());
  vuint32m1_t ctr2v = __riscv_vmv_v_x_u32m1(ctr2, __riscv_vsetvlmax_e32m1());
  vuint32m1_t ctr3v = __riscv_vmv_v_x_u32m1(ctr3, __riscv_vsetvlmax_e32m1());

  philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo,
                 rnd2hi);
}

QUALIFIERS void philox_double2(uint32 ctr0, vuint32m1_t ctr1, uint32 ctr2,
                               uint32 ctr3, uint32 key0, uint32 key1,
                               vfloat64m1_t &rnd1, vfloat64m1_t &rnd2) {
  vuint32m1_t ctr0v = __riscv_vmv_v_x_u32m1(ctr0, __riscv_vsetvlmax_e32m1());
  vuint32m1_t ctr2v = __riscv_vmv_v_x_u32m1(ctr2, __riscv_vsetvlmax_e32m1());
  vuint32m1_t ctr3v = __riscv_vmv_v_x_u32m1(ctr3, __riscv_vsetvlmax_e32m1());

  vfloat64m1_t ignore;
  philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2,
                 ignore);
}

QUALIFIERS void philox_double2(uint32 ctr0, vint32m1_t ctr1, uint32 ctr2,
                               uint32 ctr3, uint32 key0, uint32 key1,
                               vfloat64m1_t &rnd1, vfloat64m1_t &rnd2) {
  philox_double2(ctr0, __riscv_vreinterpret_v_i32m1_u32m1(ctr1), ctr2, ctr3,
                 key0, key1, rnd1, rnd2);
}
#endif

#ifdef __AVX2__
QUALIFIERS void _philox4x32round(__m256i *ctr, __m256i *key) {
  __m256i lohi0a = _mm256_mul_epu32(ctr[0], _mm256_set1_epi32(PHILOX_M4x32_0));
  __m256i lohi0b = _mm256_mul_epu32(_mm256_srli_epi64(ctr[0], 32),
                                    _mm256_set1_epi32(PHILOX_M4x32_0));
  __m256i lohi1a = _mm256_mul_epu32(ctr[2], _mm256_set1_epi32(PHILOX_M4x32_1));
  __m256i lohi1b = _mm256_mul_epu32(_mm256_srli_epi64(ctr[2], 32),
                                    _mm256_set1_epi32(PHILOX_M4x32_1));

  lohi0a = _mm256_shuffle_epi32(lohi0a, 0xD8);
  lohi0b = _mm256_shuffle_epi32(lohi0b, 0xD8);
  lohi1a = _mm256_shuffle_epi32(lohi1a, 0xD8);
  lohi1b = _mm256_shuffle_epi32(lohi1b, 0xD8);

  __m256i lo0 = _mm256_unpacklo_epi32(lohi0a, lohi0b);
  __m256i hi0 = _mm256_unpackhi_epi32(lohi0a, lohi0b);
  __m256i lo1 = _mm256_unpacklo_epi32(lohi1a, lohi1b);
  __m256i hi1 = _mm256_unpackhi_epi32(lohi1a, lohi1b);

  ctr[0] = _mm256_xor_si256(_mm256_xor_si256(hi1, ctr[1]), key[0]);
  ctr[1] = lo1;
  ctr[2] = _mm256_xor_si256(_mm256_xor_si256(hi0, ctr[3]), key[1]);
  ctr[3] = lo0;
}

QUALIFIERS void _philox4x32bumpkey(__m256i *key) {
  key[0] = _mm256_add_epi32(key[0], _mm256_set1_epi32(PHILOX_W32_0));
  key[1] = _mm256_add_epi32(key[1], _mm256_set1_epi32(PHILOX_W32_1));
}

template <bool high>
QUALIFIERS __m256d _uniform_double_hq(__m256i x, __m256i y) {
  // convert 32 to 64 bit
  if (high) {
    x = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(x, 1));
    y = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(y, 1));
  } else {
    x = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(x, 0));
    y = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(y, 0));
  }

  // calculate z = x ^ (y << (53 - 32))
  __m256i z = _mm256_sll_epi64(y, _mm_set1_epi64x(53 - 32));
  z = _mm256_xor_si256(x, z);

  // convert uint64 to double
  __m256d rs = _my256_cvtepu64_pd(z);
  // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
#ifdef __FMA__
  rs = _mm256_fmadd_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE),
                       _mm256_set1_pd(TWOPOW53_INV_DOUBLE / 2.0));
#else
  rs = _mm256_mul_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE));
  rs = _mm256_add_pd(rs, _mm256_set1_pd(TWOPOW53_INV_DOUBLE / 2.0));
#endif

  return rs;
}

QUALIFIERS void philox_float4(__m256i ctr0, __m256i ctr1, __m256i ctr2,
                              __m256i ctr3, uint32 key0, uint32 key1,
                              __m256 &rnd1, __m256 &rnd2, __m256 &rnd3,
                              __m256 &rnd4) {
  __m256i key[2] = {_mm256_set1_epi32(key0), _mm256_set1_epi32(key1)};
  __m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
  _philox4x32round(ctr, key); // 1
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 2
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 3
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 4
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 5
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 6
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 7
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 8
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 9
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 10

  // convert uint32 to float
  rnd1 = _my256_cvtepu32_ps(ctr[0]);
  rnd2 = _my256_cvtepu32_ps(ctr[1]);
  rnd3 = _my256_cvtepu32_ps(ctr[2]);
  rnd4 = _my256_cvtepu32_ps(ctr[3]);
  // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
#ifdef __FMA__
  rnd1 = _mm256_fmadd_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT),
                         _mm256_set1_ps(TWOPOW32_INV_FLOAT / 2.0));
  rnd2 = _mm256_fmadd_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT),
                         _mm256_set1_ps(TWOPOW32_INV_FLOAT / 2.0));
  rnd3 = _mm256_fmadd_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT),
                         _mm256_set1_ps(TWOPOW32_INV_FLOAT / 2.0));
  rnd4 = _mm256_fmadd_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT),
                         _mm256_set1_ps(TWOPOW32_INV_FLOAT / 2.0));
#else
  rnd1 = _mm256_mul_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
  rnd1 = _mm256_add_ps(rnd1, _mm256_set1_ps(TWOPOW32_INV_FLOAT / 2.0f));
  rnd2 = _mm256_mul_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
  rnd2 = _mm256_add_ps(rnd2, _mm256_set1_ps(TWOPOW32_INV_FLOAT / 2.0f));
  rnd3 = _mm256_mul_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
  rnd3 = _mm256_add_ps(rnd3, _mm256_set1_ps(TWOPOW32_INV_FLOAT / 2.0f));
  rnd4 = _mm256_mul_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT));
  rnd4 = _mm256_add_ps(rnd4, _mm256_set1_ps(TWOPOW32_INV_FLOAT / 2.0f));
#endif
}

QUALIFIERS void philox_double2(__m256i ctr0, __m256i ctr1, __m256i ctr2,
                               __m256i ctr3, uint32 key0, uint32 key1,
                               __m256d &rnd1lo, __m256d &rnd1hi,
                               __m256d &rnd2lo, __m256d &rnd2hi) {
  __m256i key[2] = {_mm256_set1_epi32(key0), _mm256_set1_epi32(key1)};
  __m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
  _philox4x32round(ctr, key); // 1
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 2
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 3
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 4
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 5
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 6
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 7
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 8
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 9
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 10

  rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
  rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
  rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
  rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}

QUALIFIERS void philox_float4(uint32 ctr0, __m256i ctr1, uint32 ctr2,
                              uint32 ctr3, uint32 key0, uint32 key1,
                              __m256 &rnd1, __m256 &rnd2, __m256 &rnd3,
                              __m256 &rnd4) {
  __m256i ctr0v = _mm256_set1_epi32(ctr0);
  __m256i ctr2v = _mm256_set1_epi32(ctr2);
  __m256i ctr3v = _mm256_set1_epi32(ctr3);

  philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}

QUALIFIERS void philox_double2(uint32 ctr0, __m256i ctr1, uint32 ctr2,
                               uint32 ctr3, uint32 key0, uint32 key1,
                               __m256d &rnd1lo, __m256d &rnd1hi,
                               __m256d &rnd2lo, __m256d &rnd2hi) {
  __m256i ctr0v = _mm256_set1_epi32(ctr0);
  __m256i ctr2v = _mm256_set1_epi32(ctr2);
  __m256i ctr3v = _mm256_set1_epi32(ctr3);

  philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo,
                 rnd2hi);
}

QUALIFIERS void philox_double2(uint32 ctr0, __m256i ctr1, uint32 ctr2,
                               uint32 ctr3, uint32 key0, uint32 key1,
                               __m256d &rnd1, __m256d &rnd2) {
#if 0
  __m256i ctr0v = _mm256_set1_epi32(ctr0);
  __m256i ctr2v = _mm256_set1_epi32(ctr2);
  __m256i ctr3v = _mm256_set1_epi32(ctr3);

  __m256d ignore;
  philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
#else
  __m128d rnd1lo, rnd1hi, rnd2lo, rnd2hi;
  philox_double2(ctr0, _mm256_extractf128_si256(ctr1, 0), ctr2, ctr3, key0,
                 key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
  rnd1 = _my256_set_m128d(rnd1hi, rnd1lo);
  rnd2 = _my256_set_m128d(rnd2hi, rnd2lo);
#endif
}
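// Design note on the two-output philox_double2 above: instead of splatting the
// scalar counters to full 256-bit width and discarding half of the results
// (the disabled #if 0 branch), the active branch runs the SSE kernel on the
// low 128-bit half of ctr1; since each 32-bit counter lane yields one double
// in each of the two outputs, four lanes fill both 256-bit results exactly.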
#endif

#if defined(__AVX512F__) || defined(__AVX10_512BIT__)
QUALIFIERS void _philox4x32round(__m512i *ctr, __m512i *key) {
  __m512i lohi0a = _mm512_mul_epu32(ctr[0], _mm512_set1_epi32(PHILOX_M4x32_0));
  __m512i lohi0b = _mm512_mul_epu32(_mm512_srli_epi64(ctr[0], 32),
                                    _mm512_set1_epi32(PHILOX_M4x32_0));
  __m512i lohi1a = _mm512_mul_epu32(ctr[2], _mm512_set1_epi32(PHILOX_M4x32_1));
  __m512i lohi1b = _mm512_mul_epu32(_mm512_srli_epi64(ctr[2], 32),
                                    _mm512_set1_epi32(PHILOX_M4x32_1));

  lohi0a = _mm512_shuffle_epi32(lohi0a, _MM_PERM_DBCA);
  lohi0b = _mm512_shuffle_epi32(lohi0b, _MM_PERM_DBCA);
  lohi1a = _mm512_shuffle_epi32(lohi1a, _MM_PERM_DBCA);
  lohi1b = _mm512_shuffle_epi32(lohi1b, _MM_PERM_DBCA);

  __m512i lo0 = _mm512_unpacklo_epi32(lohi0a, lohi0b);
  __m512i hi0 = _mm512_unpackhi_epi32(lohi0a, lohi0b);
  __m512i lo1 = _mm512_unpacklo_epi32(lohi1a, lohi1b);
  __m512i hi1 = _mm512_unpackhi_epi32(lohi1a, lohi1b);

  ctr[0] = _mm512_xor_si512(_mm512_xor_si512(hi1, ctr[1]), key[0]);
  ctr[1] = lo1;
  ctr[2] = _mm512_xor_si512(_mm512_xor_si512(hi0, ctr[3]), key[1]);
  ctr[3] = lo0;
}

QUALIFIERS void _philox4x32bumpkey(__m512i *key) {
  key[0] = _mm512_add_epi32(key[0], _mm512_set1_epi32(PHILOX_W32_0));
  key[1] = _mm512_add_epi32(key[1], _mm512_set1_epi32(PHILOX_W32_1));
}

template <bool high>
QUALIFIERS __m512d _uniform_double_hq(__m512i x, __m512i y) {
  // convert 32 to 64 bit
  if (high) {
    x = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(x, 1));
    y = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(y, 1));
  } else {
    x = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(x, 0));
    y = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(y, 0));
  }

  // calculate z = x ^ (y << (53 - 32))
  __m512i z = _mm512_sll_epi64(y, _mm_set1_epi64x(53 - 32));
  z = _mm512_xor_si512(x, z);

  // convert uint64 to double
  __m512d rs = _mm512_cvtepu64_pd(z);
  // calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
  rs = _mm512_fmadd_pd(rs, _mm512_set1_pd(TWOPOW53_INV_DOUBLE),
                       _mm512_set1_pd(TWOPOW53_INV_DOUBLE / 2.0));

  return rs;
}

QUALIFIERS void philox_float4(__m512i ctr0, __m512i ctr1, __m512i ctr2,
                              __m512i ctr3, uint32 key0, uint32 key1,
                              __m512 &rnd1, __m512 &rnd2, __m512 &rnd3,
                              __m512 &rnd4) {
  __m512i key[2] = {_mm512_set1_epi32(key0), _mm512_set1_epi32(key1)};
  __m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
  _philox4x32round(ctr, key); // 1
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 2
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 3
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 4
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 5
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 6
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 7
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 8
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 9
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 10

  // convert uint32 to float
  rnd1 = _mm512_cvtepu32_ps(ctr[0]);
  rnd2 = _mm512_cvtepu32_ps(ctr[1]);
  rnd3 = _mm512_cvtepu32_ps(ctr[2]);
  rnd4 = _mm512_cvtepu32_ps(ctr[3]);
  // calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
  rnd1 = _mm512_fmadd_ps(rnd1, _mm512_set1_ps(TWOPOW32_INV_FLOAT),
                         _mm512_set1_ps(TWOPOW32_INV_FLOAT / 2.0));
  rnd2 = _mm512_fmadd_ps(rnd2, _mm512_set1_ps(TWOPOW32_INV_FLOAT),
                         _mm512_set1_ps(TWOPOW32_INV_FLOAT / 2.0));
  rnd3 = _mm512_fmadd_ps(rnd3, _mm512_set1_ps(TWOPOW32_INV_FLOAT),
                         _mm512_set1_ps(TWOPOW32_INV_FLOAT / 2.0));
  rnd4 = _mm512_fmadd_ps(rnd4, _mm512_set1_ps(TWOPOW32_INV_FLOAT),
                         _mm512_set1_ps(TWOPOW32_INV_FLOAT / 2.0));
}

QUALIFIERS void philox_double2(__m512i ctr0, __m512i ctr1, __m512i ctr2,
                               __m512i ctr3, uint32 key0, uint32 key1,
                               __m512d &rnd1lo, __m512d &rnd1hi,
                               __m512d &rnd2lo, __m512d &rnd2hi) {
  __m512i key[2] = {_mm512_set1_epi32(key0), _mm512_set1_epi32(key1)};
  __m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
  _philox4x32round(ctr, key); // 1
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 2
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 3
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 4
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 5
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 6
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 7
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 8
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 9
  _philox4x32bumpkey(key);
  _philox4x32round(ctr, key); // 10

  rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
  rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
  rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
  rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}

QUALIFIERS void philox_float4(uint32 ctr0, __m512i ctr1, uint32 ctr2,
                              uint32 ctr3, uint32 key0, uint32 key1,
                              __m512 &rnd1, __m512 &rnd2, __m512 &rnd3,
                              __m512 &rnd4) {
  __m512i ctr0v = _mm512_set1_epi32(ctr0);
  __m512i ctr2v = _mm512_set1_epi32(ctr2);
  __m512i ctr3v = _mm512_set1_epi32(ctr3);

  philox_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, rnd2, rnd3, rnd4);
}

QUALIFIERS void philox_double2(uint32 ctr0, __m512i ctr1, uint32 ctr2,
                               uint32 ctr3, uint32 key0, uint32 key1,
                               __m512d &rnd1lo, __m512d &rnd1hi,
                               __m512d &rnd2lo, __m512d &rnd2hi) {
  __m512i ctr0v = _mm512_set1_epi32(ctr0);
  __m512i ctr2v = _mm512_set1_epi32(ctr2);
  __m512i ctr3v = _mm512_set1_epi32(ctr3);

  philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1lo, rnd1hi, rnd2lo,
                 rnd2hi);
}

QUALIFIERS void philox_double2(uint32 ctr0, __m512i ctr1, uint32 ctr2,
                               uint32 ctr3, uint32 key0, uint32 key1,
                               __m512d &rnd1, __m512d &rnd2) {
#if 0
  __m512i ctr0v = _mm512_set1_epi32(ctr0);
  __m512i ctr2v = _mm512_set1_epi32(ctr2);
  __m512i ctr3v = _mm512_set1_epi32(ctr3);

  __m512d ignore;
  philox_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, rnd1, ignore, rnd2, ignore);
#else
  __m256d rnd1lo, rnd1hi, rnd2lo, rnd2hi;
  philox_double2(ctr0, _mm512_extracti64x4_epi64(ctr1, 0), ctr2, ctr3, key0,
                 key1, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
  rnd1 = _my512_set_m256d(rnd1hi, rnd1lo);
  rnd2 = _my512_set_m256d(rnd2hi, rnd2lo);
#endif
}
#endif
#endif

#undef QUALIFIERS
#undef SVE_QUALIFIERS
#undef PHILOX_W32_0
#undef PHILOX_W32_1
#undef PHILOX_M4x32_0
#undef PHILOX_M4x32_1
#undef TWOPOW53_INV_DOUBLE
#undef TWOPOW32_INV_FLOAT