OpenJPH
Open-source implementation of JPEG2000 Part-15
Loading...
Searching...
No Matches
ojph_transform_avx512.cpp
Go to the documentation of this file.
1//***************************************************************************/
2// This software is released under the 2-Clause BSD license, included
3// below.
4//
5// Copyright (c) 2019-2024, Aous Naman
6// Copyright (c) 2019-2024, Kakadu Software Pty Ltd, Australia
7// Copyright (c) 2019-2024, The University of New South Wales, Australia
8//
9// Redistribution and use in source and binary forms, with or without
10// modification, are permitted provided that the following conditions are
11// met:
12//
13// 1. Redistributions of source code must retain the above copyright
14// notice, this list of conditions and the following disclaimer.
15//
16// 2. Redistributions in binary form must reproduce the above copyright
17// notice, this list of conditions and the following disclaimer in the
18// documentation and/or other materials provided with the distribution.
19//
20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
21// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
22// TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
23// PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
26// TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
27// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
28// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
29// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31//***************************************************************************/
32// This file is part of the OpenJPH software implementation.
33// File: ojph_transform_avx512.cpp
34// Author: Aous Naman
35// Date: 13 April 2024
36//***************************************************************************/
37
38#include "ojph_arch.h"
39#if defined(OJPH_ARCH_X86_64)
40
41#include <cstdio>
42
43#include "ojph_defs.h"
44#include "ojph_mem.h"
45#include "ojph_params.h"
47
48#include "ojph_transform.h"
50
51#include <immintrin.h>
52
53namespace ojph {
54 namespace local {
55
57 // We split multiples of 32 followed by multiples of 16, because
58 // we assume byte_alignment == 64
59 static
60 void avx512_deinterleave32(float* dpl, float* dph, float* sp, int width)
61 {
62 __m512i idx1 = _mm512_set_epi32(
63 0x1E, 0x1C, 0x1A, 0x18, 0x16, 0x14, 0x12, 0x10,
64 0x0E, 0x0C, 0x0A, 0x08, 0x06, 0x04, 0x02, 0x00
65 );
66 __m512i idx2 = _mm512_set_epi32(
67 0x1F, 0x1D, 0x1B, 0x19, 0x17, 0x15, 0x13, 0x11,
68 0x0F, 0x0D, 0x0B, 0x09, 0x07, 0x05, 0x03, 0x01
69 );
70 for (; width > 16; width -= 32, sp += 32, dpl += 16, dph += 16)
71 {
72 __m512 a = _mm512_load_ps(sp);
73 __m512 b = _mm512_load_ps(sp + 16);
74 __m512 c = _mm512_permutex2var_ps(a, idx1, b);
75 __m512 d = _mm512_permutex2var_ps(a, idx2, b);
76 _mm512_store_ps(dpl, c);
77 _mm512_store_ps(dph, d);
78 }
79 for (; width > 0; width -= 16, sp += 16, dpl += 8, dph += 8)
80 {
81 __m256 a = _mm256_load_ps(sp);
82 __m256 b = _mm256_load_ps(sp + 8);
83 __m256 c = _mm256_permute2f128_ps(a, b, (2 << 4) | (0));
84 __m256 d = _mm256_permute2f128_ps(a, b, (3 << 4) | (1));
85 __m256 e = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(2, 0, 2, 0));
86 __m256 f = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(3, 1, 3, 1));
87 _mm256_store_ps(dpl, e);
88 _mm256_store_ps(dph, f);
89 }
90 }
91
93 // We split multiples of 32 followed by multiples of 16, because
94 // we assume byte_alignment == 64
95 static
96 void avx512_interleave32(float* dp, float* spl, float* sph, int width)
97 {
98 __m512i idx1 = _mm512_set_epi32(
99 0x17, 0x7, 0x16, 0x6, 0x15, 0x5, 0x14, 0x4,
100 0x13, 0x3, 0x12, 0x2, 0x11, 0x1, 0x10, 0x0
101 );
102 __m512i idx2 = _mm512_set_epi32(
103 0x1F, 0xF, 0x1E, 0xE, 0x1D, 0xD, 0x1C, 0xC,
104 0x1B, 0xB, 0x1A, 0xA, 0x19, 0x9, 0x18, 0x8
105 );
106 for (; width > 16; width -= 32, dp += 32, spl += 16, sph += 16)
107 {
108 __m512 a = _mm512_load_ps(spl);
109 __m512 b = _mm512_load_ps(sph);
110 __m512 c = _mm512_permutex2var_ps(a, idx1, b);
111 __m512 d = _mm512_permutex2var_ps(a, idx2, b);
112 _mm512_store_ps(dp, c);
113 _mm512_store_ps(dp + 16, d);
114 }
115 for (; width > 0; width -= 16, dp += 16, spl += 8, sph += 8)
116 {
117 __m256 a = _mm256_load_ps(spl);
118 __m256 b = _mm256_load_ps(sph);
119 __m256 c = _mm256_unpacklo_ps(a, b);
120 __m256 d = _mm256_unpackhi_ps(a, b);
121 __m256 e = _mm256_permute2f128_ps(c, d, (2 << 4) | (0));
122 __m256 f = _mm256_permute2f128_ps(c, d, (3 << 4) | (1));
123 _mm256_store_ps(dp, e);
124 _mm256_store_ps(dp + 8, f);
125 }
126 }
127
129 // We split multiples of 32 followed by multiples of 16, because
130 // we assume byte_alignment == 64
131 static void avx512_deinterleave64(void* dpl, void* dph, const void* sp,
132 int width)
133 {
134 __m512i idx1 = _mm512_set_epi64(
135 0x0E, 0x0C, 0x0A, 0x08, 0x06, 0x04, 0x02, 0x00
136 );
137 __m512i idx2 = _mm512_set_epi64(
138 0x0F, 0x0D, 0x0B, 0x09, 0x07, 0x05, 0x03, 0x01
139 );
140 for (; width > 8; width -= 16,
141 sp = (const char*)sp + 128,
142 dpl = (char*)dpl + 64,
143 dph = (char*)dph + 64)
144 {
145 __m512i a = _mm512_load_si512(sp);
146 __m512i b = _mm512_load_si512((const char*)sp + 64);
147 __m512i c = _mm512_permutex2var_epi64(a, idx1, b);
148 __m512i d = _mm512_permutex2var_epi64(a, idx2, b);
149 _mm512_store_si512(dpl, c);
150 _mm512_store_si512(dph, d);
151 }
152 for (; width > 0; width -= 8,
153 sp = (const char*)sp + 64,
154 dpl = (char*)dpl + 32,
155 dph = (char*)dph + 32)
156 {
157 __m256i a = _mm256_load_si256((const __m256i*)sp);
158 __m256i b = _mm256_load_si256((const __m256i*)((const char*)sp + 32));
159 __m256i c = _mm256_permute2f128_si256(a, b, (2 << 4) | (0));
160 __m256i d = _mm256_permute2f128_si256(a, b, (3 << 4) | (1));
161 __m256i e = _mm256_unpacklo_epi64(c, d);
162 __m256i f = _mm256_unpackhi_epi64(c, d);
163 _mm256_store_si256((__m256i*)dpl, e);
164 _mm256_store_si256((__m256i*)dph, f);
165 }
166 }
167
169 // We split multiples of 32 followed by multiples of 16, because
170 // we assume byte_alignment == 64
171 static void avx512_interleave64(void* dp, const void* spl,
172 const void* sph, int width)
173 {
174 __m512i idx1 = _mm512_set_epi64(
175 0xB, 0x3, 0xA, 0x2, 0x9, 0x1, 0x8, 0x0
176 );
177 __m512i idx2 = _mm512_set_epi64(
178 0xF, 0x7, 0xE, 0x6, 0xD, 0x5, 0xC, 0x4
179 );
180 for (; width > 8; width -= 16,
181 dp = (char*)dp + 128,
182 spl = (const char*)spl + 64,
183 sph = (const char*)sph + 64)
184 {
185 __m512i a = _mm512_load_si512(spl);
186 __m512i b = _mm512_load_si512(sph);
187 __m512i c = _mm512_permutex2var_epi64(a, idx1, b);
188 __m512i d = _mm512_permutex2var_epi64(a, idx2, b);
189 _mm512_store_si512(dp, c);
190 _mm512_store_si512((char*)dp + 64, d);
191 }
192 for (; width > 0; width -= 8,
193 dp = (char*)dp + 64,
194 spl = (const char*)spl + 32,
195 sph = (const char*)sph + 32)
196 {
197 __m256i a = _mm256_load_si256((const __m256i*)spl);
198 __m256i b = _mm256_load_si256((const __m256i*)sph);
199 __m256i c = _mm256_unpacklo_epi64(a, b);
200 __m256i d = _mm256_unpackhi_epi64(a, b);
201 __m256i e = _mm256_permute2f128_si256(c, d, (2 << 4) | (0));
202 __m256i f = _mm256_permute2f128_si256(c, d, (3 << 4) | (1));
203 _mm256_store_si256((__m256i*)dp, e);
204 _mm256_store_si256((__m256i*)((char*)dp + 32), f);
205 }
206 }
207
// Scales `width` floats at `p` in place by `f`.  The buffer is processed in
// whole 16-float vectors, so up to 15 floats past `width` are also scaled.
// `p` must be 64-byte aligned.
static inline void avx512_multiply_const(float* p, float f, int width)
{
  const __m512 scale = _mm512_set1_ps(f);
  while (width > 0)
  {
    _mm512_store_ps(p, _mm512_mul_ps(scale, _mm512_load_ps(p)));
    p += 16;
    width -= 16;
  }
}
218
220 void avx512_irv_vert_step(const lifting_step* s, const line_buf* sig,
221 const line_buf* other, const line_buf* aug,
222 ui32 repeat, bool synthesis)
223 {
224 float a = s->irv.Aatk;
225 if (synthesis)
226 a = -a;
227
228 __m512 factor = _mm512_set1_ps(a);
229
230 float* dst = aug->f32;
231 const float* src1 = sig->f32, * src2 = other->f32;
232 int i = (int)repeat;
233 for ( ; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
234 {
235 __m512 s1 = _mm512_load_ps(src1);
236 __m512 s2 = _mm512_load_ps(src2);
237 __m512 d = _mm512_load_ps(dst);
238 d = _mm512_add_ps(d, _mm512_mul_ps(factor, _mm512_add_ps(s1, s2)));
239 _mm512_store_ps(dst, d);
240 }
241 }
242
244 void avx512_irv_vert_times_K(float K, const line_buf* aug, ui32 repeat)
245 {
246 avx512_multiply_const(aug->f32, K, (int)repeat);
247 }
248
250 void avx512_irv_horz_ana(const param_atk* atk, const line_buf* ldst,
251 const line_buf* hdst, const line_buf* src,
252 ui32 width, bool even)
253 {
254 if (width > 1)
255 {
256 // split src into ldst and hdst
257 {
258 float* dpl = even ? ldst->f32 : hdst->f32;
259 float* dph = even ? hdst->f32 : ldst->f32;
260 float* sp = src->f32;
261 int w = (int)width;
262 avx512_deinterleave32(dpl, dph, sp, w);
263 }
264
265 // the actual horizontal transform
266 float* hp = hdst->f32, * lp = ldst->f32;
267 ui32 l_width = (width + (even ? 1 : 0)) >> 1; // low pass
268 ui32 h_width = (width + (even ? 0 : 1)) >> 1; // high pass
269 ui32 num_steps = atk->get_num_steps();
270 for (ui32 j = num_steps; j > 0; --j)
271 {
272 const lifting_step* s = atk->get_step(j - 1);
273 const float a = s->irv.Aatk;
274
275 // extension
276 lp[-1] = lp[0];
277 lp[l_width] = lp[l_width - 1];
278 // lifting step
279 const float* sp = lp;
280 float* dp = hp;
281 int i = (int)h_width;
282 __m512 f = _mm512_set1_ps(a);
283 if (even)
284 {
285 for (; i > 0; i -= 16, sp += 16, dp += 16)
286 {
287 __m512 m = _mm512_load_ps(sp);
288 __m512 n = _mm512_loadu_ps(sp + 1);
289 __m512 p = _mm512_load_ps(dp);
290 p = _mm512_add_ps(p, _mm512_mul_ps(f, _mm512_add_ps(m, n)));
291 _mm512_store_ps(dp, p);
292 }
293 }
294 else
295 {
296 for (; i > 0; i -= 16, sp += 16, dp += 16)
297 {
298 __m512 m = _mm512_load_ps(sp);
299 __m512 n = _mm512_loadu_ps(sp - 1);
300 __m512 p = _mm512_load_ps(dp);
301 p = _mm512_add_ps(p, _mm512_mul_ps(f, _mm512_add_ps(m, n)));
302 _mm512_store_ps(dp, p);
303 }
304 }
305
306 // swap buffers
307 float* t = lp; lp = hp; hp = t;
308 even = !even;
309 ui32 w = l_width; l_width = h_width; h_width = w;
310 }
311
312 { // multiply by K or 1/K
313 float K = atk->get_K();
314 float K_inv = 1.0f / K;
315 avx512_multiply_const(lp, K_inv, (int)l_width);
316 avx512_multiply_const(hp, K, (int)h_width);
317 }
318 }
319 else {
320 if (even)
321 ldst->f32[0] = src->f32[0];
322 else
323 hdst->f32[0] = src->f32[0] * 2.0f;
324 }
325 }
326
328 void avx512_irv_horz_syn(const param_atk* atk, const line_buf* dst,
329 const line_buf* lsrc, const line_buf* hsrc,
330 ui32 width, bool even)
331 {
332 if (width > 1)
333 {
334 bool ev = even;
335 float* oth = hsrc->f32, * aug = lsrc->f32;
336 ui32 aug_width = (width + (even ? 1 : 0)) >> 1; // low pass
337 ui32 oth_width = (width + (even ? 0 : 1)) >> 1; // high pass
338
339 { // multiply by K or 1/K
340 float K = atk->get_K();
341 float K_inv = 1.0f / K;
342 avx512_multiply_const(aug, K, (int)aug_width);
343 avx512_multiply_const(oth, K_inv, (int)oth_width);
344 }
345
346 // the actual horizontal transform
347 ui32 num_steps = atk->get_num_steps();
348 for (ui32 j = 0; j < num_steps; ++j)
349 {
350 const lifting_step* s = atk->get_step(j);
351 const float a = s->irv.Aatk;
352
353 // extension
354 oth[-1] = oth[0];
355 oth[oth_width] = oth[oth_width - 1];
356 // lifting step
357 const float* sp = oth;
358 float* dp = aug;
359 int i = (int)aug_width;
360 __m512 f = _mm512_set1_ps(a);
361 if (ev)
362 {
363 for (; i > 0; i -= 16, sp += 16, dp += 16)
364 {
365 __m512 m = _mm512_load_ps(sp);
366 __m512 n = _mm512_loadu_ps(sp - 1);
367 __m512 p = _mm512_load_ps(dp);
368 p = _mm512_sub_ps(p, _mm512_mul_ps(f, _mm512_add_ps(m, n)));
369 _mm512_store_ps(dp, p);
370 }
371 }
372 else
373 {
374 for (; i > 0; i -= 16, sp += 16, dp += 16)
375 {
376 __m512 m = _mm512_load_ps(sp);
377 __m512 n = _mm512_loadu_ps(sp + 1);
378 __m512 p = _mm512_load_ps(dp);
379 p = _mm512_sub_ps(p, _mm512_mul_ps(f, _mm512_add_ps(m, n)));
380 _mm512_store_ps(dp, p);
381 }
382 }
383
384 // swap buffers
385 float* t = aug; aug = oth; oth = t;
386 ev = !ev;
387 ui32 w = aug_width; aug_width = oth_width; oth_width = w;
388 }
389
390 // combine both lsrc and hsrc into dst
391 {
392 float* dp = dst->f32;
393 float* spl = even ? lsrc->f32 : hsrc->f32;
394 float* sph = even ? hsrc->f32 : lsrc->f32;
395 int w = (int)width;
396 avx512_interleave32(dp, spl, sph, w);
397 }
398 }
399 else {
400 if (even)
401 dst->f32[0] = lsrc->f32[0];
402 else
403 dst->f32[0] = hsrc->f32[0] * 0.5f;
404 }
405 }
406
407
409 void avx512_rev_vert_step32(const lifting_step* s, const line_buf* sig,
410 const line_buf* other, const line_buf* aug,
411 ui32 repeat, bool synthesis)
412 {
413 const si32 a = s->rev.Aatk;
414 const si32 b = s->rev.Batk;
415 const ui8 e = s->rev.Eatk;
416 __m512i va = _mm512_set1_epi32(a);
417 __m512i vb = _mm512_set1_epi32(b);
418
419 si32* dst = aug->i32;
420 const si32* src1 = sig->i32, * src2 = other->i32;
421 // The general definition of the wavelet in Part 2 is slightly
422 // different to part 2, although they are mathematically equivalent
423 // here, we identify the simpler form from Part 1 and employ them
424 if (a == 1)
425 { // 5/3 update and any case with a == 1
426 int i = (int)repeat;
427 if (synthesis)
428 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
429 {
430 __m512i s1 = _mm512_load_si512((__m512i*)src1);
431 __m512i s2 = _mm512_load_si512((__m512i*)src2);
432 __m512i d = _mm512_load_si512((__m512i*)dst);
433 __m512i t = _mm512_add_epi32(s1, s2);
434 __m512i v = _mm512_add_epi32(vb, t);
435 __m512i w = _mm512_srai_epi32(v, e);
436 d = _mm512_sub_epi32(d, w);
437 _mm512_store_si512((__m512i*)dst, d);
438 }
439 else
440 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
441 {
442 __m512i s1 = _mm512_load_si512((__m512i*)src1);
443 __m512i s2 = _mm512_load_si512((__m512i*)src2);
444 __m512i d = _mm512_load_si512((__m512i*)dst);
445 __m512i t = _mm512_add_epi32(s1, s2);
446 __m512i v = _mm512_add_epi32(vb, t);
447 __m512i w = _mm512_srai_epi32(v, e);
448 d = _mm512_add_epi32(d, w);
449 _mm512_store_si512((__m512i*)dst, d);
450 }
451 }
452 else if (a == -1 && b == 1 && e == 1)
453 { // 5/3 predict
454 int i = (int)repeat;
455 if (synthesis)
456 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
457 {
458 __m512i s1 = _mm512_load_si512((__m512i*)src1);
459 __m512i s2 = _mm512_load_si512((__m512i*)src2);
460 __m512i d = _mm512_load_si512((__m512i*)dst);
461 __m512i t = _mm512_add_epi32(s1, s2);
462 __m512i w = _mm512_srai_epi32(t, e);
463 d = _mm512_add_epi32(d, w);
464 _mm512_store_si512((__m512i*)dst, d);
465 }
466 else
467 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
468 {
469 __m512i s1 = _mm512_load_si512((__m512i*)src1);
470 __m512i s2 = _mm512_load_si512((__m512i*)src2);
471 __m512i d = _mm512_load_si512((__m512i*)dst);
472 __m512i t = _mm512_add_epi32(s1, s2);
473 __m512i w = _mm512_srai_epi32(t, e);
474 d = _mm512_sub_epi32(d, w);
475 _mm512_store_si512((__m512i*)dst, d);
476 }
477 }
478 else if (a == -1)
479 { // any case with a == -1, which is not 5/3 predict
480 int i = (int)repeat;
481 if (synthesis)
482 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
483 {
484 __m512i s1 = _mm512_load_si512((__m512i*)src1);
485 __m512i s2 = _mm512_load_si512((__m512i*)src2);
486 __m512i d = _mm512_load_si512((__m512i*)dst);
487 __m512i t = _mm512_add_epi32(s1, s2);
488 __m512i v = _mm512_sub_epi32(vb, t);
489 __m512i w = _mm512_srai_epi32(v, e);
490 d = _mm512_sub_epi32(d, w);
491 _mm512_store_si512((__m512i*)dst, d);
492 }
493 else
494 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
495 {
496 __m512i s1 = _mm512_load_si512((__m512i*)src1);
497 __m512i s2 = _mm512_load_si512((__m512i*)src2);
498 __m512i d = _mm512_load_si512((__m512i*)dst);
499 __m512i t = _mm512_add_epi32(s1, s2);
500 __m512i v = _mm512_sub_epi32(vb, t);
501 __m512i w = _mm512_srai_epi32(v, e);
502 d = _mm512_add_epi32(d, w);
503 _mm512_store_si512((__m512i*)dst, d);
504 }
505 }
506 else { // general case
507 int i = (int)repeat;
508 if (synthesis)
509 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
510 {
511 __m512i s1 = _mm512_load_si512((__m512i*)src1);
512 __m512i s2 = _mm512_load_si512((__m512i*)src2);
513 __m512i d = _mm512_load_si512((__m512i*)dst);
514 __m512i t = _mm512_add_epi32(s1, s2);
515 __m512i u = _mm512_mullo_epi32(va, t);
516 __m512i v = _mm512_add_epi32(vb, u);
517 __m512i w = _mm512_srai_epi32(v, e);
518 d = _mm512_sub_epi32(d, w);
519 _mm512_store_si512((__m512i*)dst, d);
520 }
521 else
522 for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16)
523 {
524 __m512i s1 = _mm512_load_si512((__m512i*)src1);
525 __m512i s2 = _mm512_load_si512((__m512i*)src2);
526 __m512i d = _mm512_load_si512((__m512i*)dst);
527 __m512i t = _mm512_add_epi32(s1, s2);
528 __m512i u = _mm512_mullo_epi32(va, t);
529 __m512i v = _mm512_add_epi32(vb, u);
530 __m512i w = _mm512_srai_epi32(v, e);
531 d = _mm512_add_epi32(d, w);
532 _mm512_store_si512((__m512i*)dst, d);
533 }
534 }
535 }
536
538 void avx512_rev_vert_step64(const lifting_step* s, const line_buf* sig,
539 const line_buf* other, const line_buf* aug,
540 ui32 repeat, bool synthesis)
541 {
542 const si32 a = s->rev.Aatk;
543 const si32 b = s->rev.Batk;
544 const ui8 e = s->rev.Eatk;
545 __m512i vb = _mm512_set1_epi64(b);
546
547 si64* dst = aug->i64;
548 const si64* src1 = sig->i64, * src2 = other->i64;
549 // The general definition of the wavelet in Part 2 is slightly
550 // different to part 2, although they are mathematically equivalent
551 // here, we identify the simpler form from Part 1 and employ them
552 if (a == 1)
553 { // 5/3 update and any case with a == 1
554 int i = (int)repeat;
555 if (synthesis)
556 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
557 {
558 __m512i s1 = _mm512_load_si512((__m512i*)src1);
559 __m512i s2 = _mm512_load_si512((__m512i*)src2);
560 __m512i d = _mm512_load_si512((__m512i*)dst);
561 __m512i t = _mm512_add_epi64(s1, s2);
562 __m512i v = _mm512_add_epi64(vb, t);
563 __m512i w = _mm512_srai_epi64(v, e);
564 d = _mm512_sub_epi64(d, w);
565 _mm512_store_si512((__m512i*)dst, d);
566 }
567 else
568 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
569 {
570 __m512i s1 = _mm512_load_si512((__m512i*)src1);
571 __m512i s2 = _mm512_load_si512((__m512i*)src2);
572 __m512i d = _mm512_load_si512((__m512i*)dst);
573 __m512i t = _mm512_add_epi64(s1, s2);
574 __m512i v = _mm512_add_epi64(vb, t);
575 __m512i w = _mm512_srai_epi64(v, e);
576 d = _mm512_add_epi64(d, w);
577 _mm512_store_si512((__m512i*)dst, d);
578 }
579 }
580 else if (a == -1 && b == 1 && e == 1)
581 { // 5/3 predict
582 int i = (int)repeat;
583 if (synthesis)
584 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
585 {
586 __m512i s1 = _mm512_load_si512((__m512i*)src1);
587 __m512i s2 = _mm512_load_si512((__m512i*)src2);
588 __m512i d = _mm512_load_si512((__m512i*)dst);
589 __m512i t = _mm512_add_epi64(s1, s2);
590 __m512i w = _mm512_srai_epi64(t, e);
591 d = _mm512_add_epi64(d, w);
592 _mm512_store_si512((__m512i*)dst, d);
593 }
594 else
595 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
596 {
597 __m512i s1 = _mm512_load_si512((__m512i*)src1);
598 __m512i s2 = _mm512_load_si512((__m512i*)src2);
599 __m512i d = _mm512_load_si512((__m512i*)dst);
600 __m512i t = _mm512_add_epi64(s1, s2);
601 __m512i w = _mm512_srai_epi64(t, e);
602 d = _mm512_sub_epi64(d, w);
603 _mm512_store_si512((__m512i*)dst, d);
604 }
605 }
606 else if (a == -1)
607 { // any case with a == -1, which is not 5/3 predict
608 int i = (int)repeat;
609 if (synthesis)
610 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
611 {
612 __m512i s1 = _mm512_load_si512((__m512i*)src1);
613 __m512i s2 = _mm512_load_si512((__m512i*)src2);
614 __m512i d = _mm512_load_si512((__m512i*)dst);
615 __m512i t = _mm512_add_epi64(s1, s2);
616 __m512i v = _mm512_sub_epi64(vb, t);
617 __m512i w = _mm512_srai_epi64(v, e);
618 d = _mm512_sub_epi64(d, w);
619 _mm512_store_si512((__m512i*)dst, d);
620 }
621 else
622 for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
623 {
624 __m512i s1 = _mm512_load_si512((__m512i*)src1);
625 __m512i s2 = _mm512_load_si512((__m512i*)src2);
626 __m512i d = _mm512_load_si512((__m512i*)dst);
627 __m512i t = _mm512_add_epi64(s1, s2);
628 __m512i v = _mm512_sub_epi64(vb, t);
629 __m512i w = _mm512_srai_epi64(v, e);
630 d = _mm512_add_epi64(d, w);
631 _mm512_store_si512((__m512i*)dst, d);
632 }
633 }
634 else {
635 // general case
636 // 64bit multiplication is not supported in AVX512F + AVX512CD;
637 // in particular, _mm256_mullo_epi64.
638 if (synthesis)
639 for (ui32 i = repeat; i > 0; --i)
640 *dst++ -= (b + a * (*src1++ + *src2++)) >> e;
641 else
642 for (ui32 i = repeat; i > 0; --i)
643 *dst++ += (b + a * (*src1++ + *src2++)) >> e;
644 }
645
646 // This can only be used if you have AVX512DQ
647 // { // general case
648 // __m512i va = _mm512_set1_epi64(a);
649 // int i = (int)repeat;
650 // if (synthesis)
651 // for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
652 // {
653 // __m512i s1 = _mm512_load_si512((__m512i*)src1);
654 // __m512i s2 = _mm512_load_si512((__m512i*)src2);
655 // __m512i d = _mm512_load_si512((__m512i*)dst);
656 // __m512i t = _mm512_add_epi64(s1, s2);
657 // __m512i u = _mm512_mullo_epi64(va, t);
658 // __m512i v = _mm512_add_epi64(vb, u);
659 // __m512i w = _mm512_srai_epi64(v, e);
660 // d = _mm512_sub_epi64(d, w);
661 // _mm512_store_si512((__m512i*)dst, d);
662 // }
663 // else
664 // for (; i > 0; i -= 8, dst += 8, src1 += 8, src2 += 8)
665 // {
666 // __m512i s1 = _mm512_load_si512((__m512i*)src1);
667 // __m512i s2 = _mm512_load_si512((__m512i*)src2);
668 // __m512i d = _mm512_load_si512((__m512i*)dst);
669 // __m512i t = _mm512_add_epi64(s1, s2);
670 // __m512i u = _mm512_mullo_epi64(va, t);
671 // __m512i v = _mm512_add_epi64(vb, u);
672 // __m512i w = _mm512_srai_epi64(v, e);
673 // d = _mm512_add_epi64(d, w);
674 // _mm512_store_si512((__m512i*)dst, d);
675 // }
676 // }
677 }
678
680 void avx512_rev_vert_step(const lifting_step* s, const line_buf* sig,
681 const line_buf* other, const line_buf* aug,
682 ui32 repeat, bool synthesis)
683 {
684 if (((sig != NULL) && (sig->flags & line_buf::LFT_32BIT)) ||
685 ((aug != NULL) && (aug->flags & line_buf::LFT_32BIT)) ||
686 ((other != NULL) && (other->flags & line_buf::LFT_32BIT)))
687 {
688 assert((sig == NULL || sig->flags & line_buf::LFT_32BIT) &&
689 (other == NULL || other->flags & line_buf::LFT_32BIT) &&
690 (aug == NULL || aug->flags & line_buf::LFT_32BIT));
691 avx512_rev_vert_step32(s, sig, other, aug, repeat, synthesis);
692 }
693 else
694 {
695 assert((sig == NULL || sig->flags & line_buf::LFT_64BIT) &&
696 (other == NULL || other->flags & line_buf::LFT_64BIT) &&
697 (aug == NULL || aug->flags & line_buf::LFT_64BIT));
698 avx512_rev_vert_step64(s, sig, other, aug, repeat, synthesis);
699 }
700 }
701
703 void avx512_rev_horz_ana32(const param_atk* atk, const line_buf* ldst,
704 const line_buf* hdst, const line_buf* src,
705 ui32 width, bool even)
706 {
707 if (width > 1)
708 {
709 // split src into ldst and hdst
710 {
711 float* dpl = even ? ldst->f32 : hdst->f32;
712 float* dph = even ? hdst->f32 : ldst->f32;
713 float* sp = src->f32;
714 int w = (int)width;
715 avx512_deinterleave32(dpl, dph, sp, w);
716 }
717
718 si32* hp = hdst->i32, * lp = ldst->i32;
719 ui32 l_width = (width + (even ? 1 : 0)) >> 1; // low pass
720 ui32 h_width = (width + (even ? 0 : 1)) >> 1; // high pass
721 ui32 num_steps = atk->get_num_steps();
722 for (ui32 j = num_steps; j > 0; --j)
723 {
724 // first lifting step
725 const lifting_step* s = atk->get_step(j - 1);
726 const si32 a = s->rev.Aatk;
727 const si32 b = s->rev.Batk;
728 const ui8 e = s->rev.Eatk;
729 __m512i va = _mm512_set1_epi32(a);
730 __m512i vb = _mm512_set1_epi32(b);
731
732 // extension
733 lp[-1] = lp[0];
734 lp[l_width] = lp[l_width - 1];
735 // lifting step
736 const si32* sp = lp;
737 si32* dp = hp;
738 if (a == 1)
739 { // 5/3 update and any case with a == 1
740 int i = (int)h_width;
741 if (even)
742 {
743 for (; i > 0; i -= 16, sp += 16, dp += 16)
744 {
745 __m512i s1 = _mm512_load_si512((__m512i*)sp);
746 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
747 __m512i d = _mm512_load_si512((__m512i*)dp);
748 __m512i t = _mm512_add_epi32(s1, s2);
749 __m512i v = _mm512_add_epi32(vb, t);
750 __m512i w = _mm512_srai_epi32(v, e);
751 d = _mm512_add_epi32(d, w);
752 _mm512_store_si512((__m512i*)dp, d);
753 }
754 }
755 else
756 {
757 for (; i > 0; i -= 16, sp += 16, dp += 16)
758 {
759 __m512i s1 = _mm512_load_si512((__m512i*)sp);
760 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
761 __m512i d = _mm512_load_si512((__m512i*)dp);
762 __m512i t = _mm512_add_epi32(s1, s2);
763 __m512i v = _mm512_add_epi32(vb, t);
764 __m512i w = _mm512_srai_epi32(v, e);
765 d = _mm512_add_epi32(d, w);
766 _mm512_store_si512((__m512i*)dp, d);
767 }
768 }
769 }
770 else if (a == -1 && b == 1 && e == 1)
771 { // 5/3 predict
772 int i = (int)h_width;
773 if (even)
774 for (; i > 0; i -= 16, sp += 16, dp += 16)
775 {
776 __m512i s1 = _mm512_load_si512((__m512i*)sp);
777 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
778 __m512i d = _mm512_load_si512((__m512i*)dp);
779 __m512i t = _mm512_add_epi32(s1, s2);
780 __m512i w = _mm512_srai_epi32(t, e);
781 d = _mm512_sub_epi32(d, w);
782 _mm512_store_si512((__m512i*)dp, d);
783 }
784 else
785 for (; i > 0; i -= 16, sp += 16, dp += 16)
786 {
787 __m512i s1 = _mm512_load_si512((__m512i*)sp);
788 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
789 __m512i d = _mm512_load_si512((__m512i*)dp);
790 __m512i t = _mm512_add_epi32(s1, s2);
791 __m512i w = _mm512_srai_epi32(t, e);
792 d = _mm512_sub_epi32(d, w);
793 _mm512_store_si512((__m512i*)dp, d);
794 }
795 }
796 else if (a == -1)
797 { // any case with a == -1, which is not 5/3 predict
798 int i = (int)h_width;
799 if (even)
800 for (; i > 0; i -= 16, sp += 16, dp += 16)
801 {
802 __m512i s1 = _mm512_load_si512((__m512i*)sp);
803 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
804 __m512i d = _mm512_load_si512((__m512i*)dp);
805 __m512i t = _mm512_add_epi32(s1, s2);
806 __m512i v = _mm512_sub_epi32(vb, t);
807 __m512i w = _mm512_srai_epi32(v, e);
808 d = _mm512_add_epi32(d, w);
809 _mm512_store_si512((__m512i*)dp, d);
810 }
811 else
812 for (; i > 0; i -= 16, sp += 16, dp += 16)
813 {
814 __m512i s1 = _mm512_load_si512((__m512i*)sp);
815 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
816 __m512i d = _mm512_load_si512((__m512i*)dp);
817 __m512i t = _mm512_add_epi32(s1, s2);
818 __m512i v = _mm512_sub_epi32(vb, t);
819 __m512i w = _mm512_srai_epi32(v, e);
820 d = _mm512_add_epi32(d, w);
821 _mm512_store_si512((__m512i*)dp, d);
822 }
823 }
824 else {
825 // general case
826 int i = (int)h_width;
827 if (even)
828 for (; i > 0; i -= 16, sp += 16, dp += 16)
829 {
830 __m512i s1 = _mm512_load_si512((__m512i*)sp);
831 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
832 __m512i d = _mm512_load_si512((__m512i*)dp);
833 __m512i t = _mm512_add_epi32(s1, s2);
834 __m512i u = _mm512_mullo_epi32(va, t);
835 __m512i v = _mm512_add_epi32(vb, u);
836 __m512i w = _mm512_srai_epi32(v, e);
837 d = _mm512_add_epi32(d, w);
838 _mm512_store_si512((__m512i*)dp, d);
839 }
840 else
841 for (; i > 0; i -= 16, sp += 16, dp += 16)
842 {
843 __m512i s1 = _mm512_load_si512((__m512i*)sp);
844 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
845 __m512i d = _mm512_load_si512((__m512i*)dp);
846 __m512i t = _mm512_add_epi32(s1, s2);
847 __m512i u = _mm512_mullo_epi32(va, t);
848 __m512i v = _mm512_add_epi32(vb, u);
849 __m512i w = _mm512_srai_epi32(v, e);
850 d = _mm512_add_epi32(d, w);
851 _mm512_store_si512((__m512i*)dp, d);
852 }
853 }
854
855 // swap buffers
856 si32* t = lp; lp = hp; hp = t;
857 even = !even;
858 ui32 w = l_width; l_width = h_width; h_width = w;
859 }
860 }
861 else {
862 if (even)
863 ldst->i32[0] = src->i32[0];
864 else
865 hdst->i32[0] = src->i32[0] << 1;
866 }
867 }
868
870 void avx512_rev_horz_ana64(const param_atk* atk, const line_buf* ldst,
871 const line_buf* hdst, const line_buf* src,
872 ui32 width, bool even)
873 {
874 if (width > 1)
875 {
876 // split src into ldst and hdst
877 {
878 void* dpl = even ? ldst->p : hdst->p;
879 void* dph = even ? hdst->p : ldst->p;
880 const void* sp = src->p;
881 int w = (int)width;
882 avx512_deinterleave64(dpl, dph, sp, w);
883 }
884
885 si64* hp = hdst->i64, * lp = ldst->i64;
886 ui32 l_width = (width + (even ? 1 : 0)) >> 1; // low pass
887 ui32 h_width = (width + (even ? 0 : 1)) >> 1; // high pass
888 ui32 num_steps = atk->get_num_steps();
889 for (ui32 j = num_steps; j > 0; --j)
890 {
891 // first lifting step
892 const lifting_step* s = atk->get_step(j - 1);
893 const si32 a = s->rev.Aatk;
894 const si32 b = s->rev.Batk;
895 const ui8 e = s->rev.Eatk;
896 __m512i vb = _mm512_set1_epi64(b);
897
898 // extension
899 lp[-1] = lp[0];
900 lp[l_width] = lp[l_width - 1];
901 // lifting step
902 const si64* sp = lp;
903 si64* dp = hp;
904 if (a == 1)
905 { // 5/3 update and any case with a == 1
906 int i = (int)h_width;
907 if (even)
908 {
909 for (; i > 0; i -= 8, sp += 8, dp += 8)
910 {
911 __m512i s1 = _mm512_load_si512((__m512i*)sp);
912 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
913 __m512i d = _mm512_load_si512((__m512i*)dp);
914 __m512i t = _mm512_add_epi64(s1, s2);
915 __m512i v = _mm512_add_epi64(vb, t);
916 __m512i w = _mm512_srai_epi64(v, e);
917 d = _mm512_add_epi64(d, w);
918 _mm512_store_si512((__m512i*)dp, d);
919 }
920 }
921 else
922 {
923 for (; i > 0; i -= 8, sp += 8, dp += 8)
924 {
925 __m512i s1 = _mm512_load_si512((__m512i*)sp);
926 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
927 __m512i d = _mm512_load_si512((__m512i*)dp);
928 __m512i t = _mm512_add_epi64(s1, s2);
929 __m512i v = _mm512_add_epi64(vb, t);
930 __m512i w = _mm512_srai_epi64(v, e);
931 d = _mm512_add_epi64(d, w);
932 _mm512_store_si512((__m512i*)dp, d);
933 }
934 }
935 }
936 else if (a == -1 && b == 1 && e == 1)
937 { // 5/3 predict
938 int i = (int)h_width;
939 if (even)
940 for (; i > 0; i -= 8, sp += 8, dp += 8)
941 {
942 __m512i s1 = _mm512_load_si512((__m512i*)sp);
943 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
944 __m512i d = _mm512_load_si512((__m512i*)dp);
945 __m512i t = _mm512_add_epi64(s1, s2);
946 __m512i w = _mm512_srai_epi64(t, e);
947 d = _mm512_sub_epi64(d, w);
948 _mm512_store_si512((__m512i*)dp, d);
949 }
950 else
951 for (; i > 0; i -= 8, sp += 8, dp += 8)
952 {
953 __m512i s1 = _mm512_load_si512((__m512i*)sp);
954 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
955 __m512i d = _mm512_load_si512((__m512i*)dp);
956 __m512i t = _mm512_add_epi64(s1, s2);
957 __m512i w = _mm512_srai_epi64(t, e);
958 d = _mm512_sub_epi64(d, w);
959 _mm512_store_si512((__m512i*)dp, d);
960 }
961 }
962 else if (a == -1)
963 { // any case with a == -1, which is not 5/3 predict
964 int i = (int)h_width;
965 if (even)
966 for (; i > 0; i -= 8, sp += 8, dp += 8)
967 {
968 __m512i s1 = _mm512_load_si512((__m512i*)sp);
969 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
970 __m512i d = _mm512_load_si512((__m512i*)dp);
971 __m512i t = _mm512_add_epi64(s1, s2);
972 __m512i v = _mm512_sub_epi64(vb, t);
973 __m512i w = _mm512_srai_epi64(v, e);
974 d = _mm512_add_epi64(d, w);
975 _mm512_store_si512((__m512i*)dp, d);
976 }
977 else
978 for (; i > 0; i -= 8, sp += 8, dp += 8)
979 {
980 __m512i s1 = _mm512_load_si512((__m512i*)sp);
981 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
982 __m512i d = _mm512_load_si512((__m512i*)dp);
983 __m512i t = _mm512_add_epi64(s1, s2);
984 __m512i v = _mm512_sub_epi64(vb, t);
985 __m512i w = _mm512_srai_epi64(v, e);
986 d = _mm512_add_epi64(d, w);
987 _mm512_store_si512((__m512i*)dp, d);
988 }
989 }
990 else
991 {
992 // general case
993 // 64bit multiplication is not supported in AVX512F + AVX512CD;
994 // in particular, _mm256_mullo_epi64.
995 if (even)
996 for (ui32 i = h_width; i > 0; --i, sp++, dp++)
997 *dp += (b + a * (sp[0] + sp[1])) >> e;
998 else
999 for (ui32 i = h_width; i > 0; --i, sp++, dp++)
1000 *dp += (b + a * (sp[-1] + sp[0])) >> e;
1001 }
1002
1003 // This can only be used if you have AVX512DQ
1004 // {
1005 // // general case
1006 // __m512i va = _mm512_set1_epi64(a);
1007 // int i = (int)h_width;
1008 // if (even)
1009 // for (; i > 0; i -= 8, sp += 8, dp += 8)
1010 // {
1011 // __m512i s1 = _mm512_load_si512((__m512i*)sp);
1012 // __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1013 // __m512i d = _mm512_load_si512((__m512i*)dp);
1014 // __m512i t = _mm512_add_epi64(s1, s2);
1015 // __m512i u = _mm512_mullo_epi64(va, t);
1016 // __m512i v = _mm512_add_epi64(vb, u);
1017 // __m512i w = _mm512_srai_epi64(v, e);
1018 // d = _mm512_add_epi64(d, w);
1019 // _mm512_store_si512((__m512i*)dp, d);
1020 // }
1021 // else
1022 // for (; i > 0; i -= 8, sp += 8, dp += 8)
1023 // {
1024 // __m512i s1 = _mm512_load_si512((__m512i*)sp);
1025 // __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1026 // __m512i d = _mm512_load_si512((__m512i*)dp);
1027 // __m512i t = _mm512_add_epi64(s1, s2);
1028 // __m512i u = _mm512_mullo_epi64(va, t);
1029 // __m512i v = _mm512_add_epi64(vb, u);
1030 // __m512i w = _mm512_srai_epi64(v, e);
1031 // d = _mm512_add_epi64(d, w);
1032 // _mm512_store_si512((__m512i*)dp, d);
1033 // }
1034 // }
1035
1036 // swap buffers
1037 si64* t = lp; lp = hp; hp = t;
1038 even = !even;
1039 ui32 w = l_width; l_width = h_width; h_width = w;
1040 }
1041 }
1042 else {
1043 if (even)
1044 ldst->i64[0] = src->i64[0];
1045 else
1046 hdst->i64[0] = src->i64[0] << 1;
1047 }
1048 }
1049
1051 void avx512_rev_horz_ana(const param_atk* atk, const line_buf* ldst,
1052 const line_buf* hdst, const line_buf* src,
1053 ui32 width, bool even)
1054 {
1055 if (src->flags & line_buf::LFT_32BIT)
1056 {
1057 assert((ldst == NULL || ldst->flags & line_buf::LFT_32BIT) &&
1058 (hdst == NULL || hdst->flags & line_buf::LFT_32BIT));
1059 avx512_rev_horz_ana32(atk, ldst, hdst, src, width, even);
1060 }
1061 else
1062 {
1063 assert((ldst == NULL || ldst->flags & line_buf::LFT_64BIT) &&
1064 (hdst == NULL || hdst->flags & line_buf::LFT_64BIT) &&
1065 (src == NULL || src->flags & line_buf::LFT_64BIT));
1066 avx512_rev_horz_ana64(atk, ldst, hdst, src, width, even);
1067 }
1068 }
1069
1071 void avx512_rev_horz_syn32(const param_atk* atk, const line_buf* dst,
1072 const line_buf* lsrc, const line_buf* hsrc,
1073 ui32 width, bool even)
1074 {
1075 if (width > 1)
1076 {
1077 bool ev = even;
1078 si32* oth = hsrc->i32, * aug = lsrc->i32;
1079 ui32 aug_width = (width + (even ? 1 : 0)) >> 1; // low pass
1080 ui32 oth_width = (width + (even ? 0 : 1)) >> 1; // high pass
1081 ui32 num_steps = atk->get_num_steps();
1082 for (ui32 j = 0; j < num_steps; ++j)
1083 {
1084 const lifting_step* s = atk->get_step(j);
1085 const si32 a = s->rev.Aatk;
1086 const si32 b = s->rev.Batk;
1087 const ui8 e = s->rev.Eatk;
1088 __m512i va = _mm512_set1_epi32(a);
1089 __m512i vb = _mm512_set1_epi32(b);
1090
1091 // extension
1092 oth[-1] = oth[0];
1093 oth[oth_width] = oth[oth_width - 1];
1094 // lifting step
1095 const si32* sp = oth;
1096 si32* dp = aug;
1097 if (a == 1)
1098 { // 5/3 update and any case with a == 1
1099 int i = (int)aug_width;
1100 if (ev)
1101 {
1102 for (; i > 0; i -= 16, sp += 16, dp += 16)
1103 {
1104 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1105 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1106 __m512i d = _mm512_load_si512((__m512i*)dp);
1107 __m512i t = _mm512_add_epi32(s1, s2);
1108 __m512i v = _mm512_add_epi32(vb, t);
1109 __m512i w = _mm512_srai_epi32(v, e);
1110 d = _mm512_sub_epi32(d, w);
1111 _mm512_store_si512((__m512i*)dp, d);
1112 }
1113 }
1114 else
1115 {
1116 for (; i > 0; i -= 16, sp += 16, dp += 16)
1117 {
1118 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1119 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1120 __m512i d = _mm512_load_si512((__m512i*)dp);
1121 __m512i t = _mm512_add_epi32(s1, s2);
1122 __m512i v = _mm512_add_epi32(vb, t);
1123 __m512i w = _mm512_srai_epi32(v, e);
1124 d = _mm512_sub_epi32(d, w);
1125 _mm512_store_si512((__m512i*)dp, d);
1126 }
1127 }
1128 }
1129 else if (a == -1 && b == 1 && e == 1)
1130 { // 5/3 predict
1131 int i = (int)aug_width;
1132 if (ev)
1133 for (; i > 0; i -= 16, sp += 16, dp += 16)
1134 {
1135 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1136 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1137 __m512i d = _mm512_load_si512((__m512i*)dp);
1138 __m512i t = _mm512_add_epi32(s1, s2);
1139 __m512i w = _mm512_srai_epi32(t, e);
1140 d = _mm512_add_epi32(d, w);
1141 _mm512_store_si512((__m512i*)dp, d);
1142 }
1143 else
1144 for (; i > 0; i -= 16, sp += 16, dp += 16)
1145 {
1146 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1147 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1148 __m512i d = _mm512_load_si512((__m512i*)dp);
1149 __m512i t = _mm512_add_epi32(s1, s2);
1150 __m512i w = _mm512_srai_epi32(t, e);
1151 d = _mm512_add_epi32(d, w);
1152 _mm512_store_si512((__m512i*)dp, d);
1153 }
1154 }
1155 else if (a == -1)
1156 { // any case with a == -1, which is not 5/3 predict
1157 int i = (int)aug_width;
1158 if (ev)
1159 for (; i > 0; i -= 16, sp += 16, dp += 16)
1160 {
1161 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1162 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1163 __m512i d = _mm512_load_si512((__m512i*)dp);
1164 __m512i t = _mm512_add_epi32(s1, s2);
1165 __m512i v = _mm512_sub_epi32(vb, t);
1166 __m512i w = _mm512_srai_epi32(v, e);
1167 d = _mm512_sub_epi32(d, w);
1168 _mm512_store_si512((__m512i*)dp, d);
1169 }
1170 else
1171 for (; i > 0; i -= 16, sp += 16, dp += 16)
1172 {
1173 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1174 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1175 __m512i d = _mm512_load_si512((__m512i*)dp);
1176 __m512i t = _mm512_add_epi32(s1, s2);
1177 __m512i v = _mm512_sub_epi32(vb, t);
1178 __m512i w = _mm512_srai_epi32(v, e);
1179 d = _mm512_sub_epi32(d, w);
1180 _mm512_store_si512((__m512i*)dp, d);
1181 }
1182 }
1183 else {
1184 // general case
1185 int i = (int)aug_width;
1186 if (ev)
1187 for (; i > 0; i -= 16, sp += 16, dp += 16)
1188 {
1189 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1190 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1191 __m512i d = _mm512_load_si512((__m512i*)dp);
1192 __m512i t = _mm512_add_epi32(s1, s2);
1193 __m512i u = _mm512_mullo_epi32(va, t);
1194 __m512i v = _mm512_add_epi32(vb, u);
1195 __m512i w = _mm512_srai_epi32(v, e);
1196 d = _mm512_sub_epi32(d, w);
1197 _mm512_store_si512((__m512i*)dp, d);
1198 }
1199 else
1200 for (; i > 0; i -= 16, sp += 16, dp += 16)
1201 {
1202 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1203 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1204 __m512i d = _mm512_load_si512((__m512i*)dp);
1205 __m512i t = _mm512_add_epi32(s1, s2);
1206 __m512i u = _mm512_mullo_epi32(va, t);
1207 __m512i v = _mm512_add_epi32(vb, u);
1208 __m512i w = _mm512_srai_epi32(v, e);
1209 d = _mm512_sub_epi32(d, w);
1210 _mm512_store_si512((__m512i*)dp, d);
1211 }
1212 }
1213
1214 // swap buffers
1215 si32* t = aug; aug = oth; oth = t;
1216 ev = !ev;
1217 ui32 w = aug_width; aug_width = oth_width; oth_width = w;
1218 }
1219
1220 // combine both lsrc and hsrc into dst
1221 {
1222 float* dp = dst->f32;
1223 float* spl = even ? lsrc->f32 : hsrc->f32;
1224 float* sph = even ? hsrc->f32 : lsrc->f32;
1225 int w = (int)width;
1226 avx512_interleave32(dp, spl, sph, w);
1227 }
1228 }
1229 else {
1230 if (even)
1231 dst->i32[0] = lsrc->i32[0];
1232 else
1233 dst->i32[0] = hsrc->i32[0] >> 1;
1234 }
1235 }
1236
1238 void avx512_rev_horz_syn64(const param_atk* atk, const line_buf* dst,
1239 const line_buf* lsrc, const line_buf* hsrc,
1240 ui32 width, bool even)
1241 {
1242 if (width > 1)
1243 {
1244 bool ev = even;
1245 si64* oth = hsrc->i64, * aug = lsrc->i64;
1246 ui32 aug_width = (width + (even ? 1 : 0)) >> 1; // low pass
1247 ui32 oth_width = (width + (even ? 0 : 1)) >> 1; // high pass
1248 ui32 num_steps = atk->get_num_steps();
1249 for (ui32 j = 0; j < num_steps; ++j)
1250 {
1251 const lifting_step* s = atk->get_step(j);
1252 const si32 a = s->rev.Aatk;
1253 const si32 b = s->rev.Batk;
1254 const ui8 e = s->rev.Eatk;
1255 __m512i vb = _mm512_set1_epi64(b);
1256
1257 // extension
1258 oth[-1] = oth[0];
1259 oth[oth_width] = oth[oth_width - 1];
1260 // lifting step
1261 const si64* sp = oth;
1262 si64* dp = aug;
1263 if (a == 1)
1264 { // 5/3 update and any case with a == 1
1265 int i = (int)aug_width;
1266 if (ev)
1267 {
1268 for (; i > 0; i -= 8, sp += 8, dp += 8)
1269 {
1270 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1271 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1272 __m512i d = _mm512_load_si512((__m512i*)dp);
1273 __m512i t = _mm512_add_epi64(s1, s2);
1274 __m512i v = _mm512_add_epi64(vb, t);
1275 __m512i w = _mm512_srai_epi64(v, e);
1276 d = _mm512_sub_epi64(d, w);
1277 _mm512_store_si512((__m512i*)dp, d);
1278 }
1279 }
1280 else
1281 {
1282 for (; i > 0; i -= 8, sp += 8, dp += 8)
1283 {
1284 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1285 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1286 __m512i d = _mm512_load_si512((__m512i*)dp);
1287 __m512i t = _mm512_add_epi64(s1, s2);
1288 __m512i v = _mm512_add_epi64(vb, t);
1289 __m512i w = _mm512_srai_epi64(v, e);
1290 d = _mm512_sub_epi64(d, w);
1291 _mm512_store_si512((__m512i*)dp, d);
1292 }
1293 }
1294 }
1295 else if (a == -1 && b == 1 && e == 1)
1296 { // 5/3 predict
1297 int i = (int)aug_width;
1298 if (ev)
1299 for (; i > 0; i -= 8, sp += 8, dp += 8)
1300 {
1301 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1302 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1303 __m512i d = _mm512_load_si512((__m512i*)dp);
1304 __m512i t = _mm512_add_epi64(s1, s2);
1305 __m512i w = _mm512_srai_epi64(t, e);
1306 d = _mm512_add_epi64(d, w);
1307 _mm512_store_si512((__m512i*)dp, d);
1308 }
1309 else
1310 for (; i > 0; i -= 8, sp += 8, dp += 8)
1311 {
1312 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1313 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1314 __m512i d = _mm512_load_si512((__m512i*)dp);
1315 __m512i t = _mm512_add_epi64(s1, s2);
1316 __m512i w = _mm512_srai_epi64(t, e);
1317 d = _mm512_add_epi64(d, w);
1318 _mm512_store_si512((__m512i*)dp, d);
1319 }
1320 }
1321 else if (a == -1)
1322 { // any case with a == -1, which is not 5/3 predict
1323 int i = (int)aug_width;
1324 if (ev)
1325 for (; i > 0; i -= 8, sp += 8, dp += 8)
1326 {
1327 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1328 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1329 __m512i d = _mm512_load_si512((__m512i*)dp);
1330 __m512i t = _mm512_add_epi64(s1, s2);
1331 __m512i v = _mm512_sub_epi64(vb, t);
1332 __m512i w = _mm512_srai_epi64(v, e);
1333 d = _mm512_sub_epi64(d, w);
1334 _mm512_store_si512((__m512i*)dp, d);
1335 }
1336 else
1337 for (; i > 0; i -= 8, sp += 8, dp += 8)
1338 {
1339 __m512i s1 = _mm512_load_si512((__m512i*)sp);
1340 __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1341 __m512i d = _mm512_load_si512((__m512i*)dp);
1342 __m512i t = _mm512_add_epi64(s1, s2);
1343 __m512i v = _mm512_sub_epi64(vb, t);
1344 __m512i w = _mm512_srai_epi64(v, e);
1345 d = _mm512_sub_epi64(d, w);
1346 _mm512_store_si512((__m512i*)dp, d);
1347 }
1348 }
1349 else
1350 {
1351 // general case
1352 // 64bit multiplication is not supported in AVX512F + AVX512CD;
1353 // in particular, _mm256_mullo_epi64.
1354 if (ev)
1355 for (ui32 i = aug_width; i > 0; --i, sp++, dp++)
1356 *dp -= (b + a * (sp[-1] + sp[0])) >> e;
1357 else
1358 for (ui32 i = aug_width; i > 0; --i, sp++, dp++)
1359 *dp -= (b + a * (sp[0] + sp[1])) >> e;
1360 }
1361
1362 // This can only be used if you have AVX512DQ
1363 // {
1364 // // general case
1365 // __m512i va = _mm512_set1_epi64(a);
1366 // int i = (int)aug_width;
1367 // if (ev)
1368 // for (; i > 0; i -= 8, sp += 8, dp += 8)
1369 // {
1370 // __m512i s1 = _mm512_load_si512((__m512i*)sp);
1371 // __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1));
1372 // __m512i d = _mm512_load_si512((__m512i*)dp);
1373 // __m512i t = _mm512_add_epi64(s1, s2);
1374 // __m512i u = _mm512_mullo_epi64(va, t);
1375 // __m512i v = _mm512_add_epi64(vb, u);
1376 // __m512i w = _mm512_srai_epi64(v, e);
1377 // d = _mm512_sub_epi64(d, w);
1378 // _mm512_store_si512((__m512i*)dp, d);
1379 // }
1380 // else
1381 // for (; i > 0; i -= 8, sp += 8, dp += 8)
1382 // {
1383 // __m512i s1 = _mm512_load_si512((__m512i*)sp);
1384 // __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1));
1385 // __m512i d = _mm512_load_si512((__m512i*)dp);
1386 // __m512i t = _mm512_add_epi64(s1, s2);
1387 // __m512i u = _mm512_mullo_epi64(va, t);
1388 // __m512i v = _mm512_add_epi64(vb, u);
1389 // __m512i w = _mm512_srai_epi64(v, e);
1390 // d = _mm512_sub_epi64(d, w);
1391 // _mm512_store_si512((__m512i*)dp, d);
1392 // }
1393 // }
1394
1395 // swap buffers
1396 si64* t = aug; aug = oth; oth = t;
1397 ev = !ev;
1398 ui32 w = aug_width; aug_width = oth_width; oth_width = w;
1399 }
1400
1401 // combine both lsrc and hsrc into dst
1402 {
1403 void* dp = dst->p;
1404 const void* spl = even ? lsrc->p : hsrc->p;
1405 const void* sph = even ? hsrc->p : lsrc->p;
1406 int w = (int)width;
1407 avx512_interleave64(dp, spl, sph, w);
1408 }
1409 }
1410 else {
1411 if (even)
1412 dst->i64[0] = lsrc->i64[0];
1413 else
1414 dst->i64[0] = hsrc->i64[0] >> 1;
1415 }
1416 }
1417
1419 void avx512_rev_horz_syn(const param_atk* atk, const line_buf* dst,
1420 const line_buf* lsrc, const line_buf* hsrc,
1421 ui32 width, bool even)
1422 {
1423 if (dst->flags & line_buf::LFT_32BIT)
1424 {
1425 assert((lsrc == NULL || lsrc->flags & line_buf::LFT_32BIT) &&
1426 (hsrc == NULL || hsrc->flags & line_buf::LFT_32BIT));
1427 avx512_rev_horz_syn32(atk, dst, lsrc, hsrc, width, even);
1428 }
1429 else
1430 {
1431 assert((dst == NULL || dst->flags & line_buf::LFT_64BIT) &&
1432 (lsrc == NULL || lsrc->flags & line_buf::LFT_64BIT) &&
1433 (hsrc == NULL || hsrc->flags & line_buf::LFT_64BIT));
1434 avx512_rev_horz_syn64(atk, dst, lsrc, hsrc, width, even);
1435 }
1436 }
1437
1438 } // !local
1439} // !ojph
1440
1441#endif
void avx512_irv_vert_step(const lifting_step *s, const line_buf *sig, const line_buf *other, const line_buf *aug, ui32 repeat, bool synthesis)
void avx512_rev_horz_syn(const param_atk *atk, const line_buf *dst, const line_buf *lsrc, const line_buf *hsrc, ui32 width, bool even)
void avx512_rev_vert_step(const lifting_step *s, const line_buf *sig, const line_buf *other, const line_buf *aug, ui32 repeat, bool synthesis)
void avx512_irv_horz_ana(const param_atk *atk, const line_buf *ldst, const line_buf *hdst, const line_buf *src, ui32 width, bool even)
void avx512_irv_vert_times_K(float K, const line_buf *aug, ui32 repeat)
void avx512_irv_horz_syn(const param_atk *atk, const line_buf *dst, const line_buf *lsrc, const line_buf *hsrc, ui32 width, bool even)
void avx512_rev_horz_ana(const param_atk *atk, const line_buf *ldst, const line_buf *hdst, const line_buf *src, ui32 width, bool even)
int64_t si64
Definition ojph_defs.h:57
int32_t si32
Definition ojph_defs.h:55
uint32_t ui32
Definition ojph_defs.h:54
uint8_t ui8
Definition ojph_defs.h:50