maomao90's Library
A C++20 library for competitive programming.
Loading...
Searching...
No Matches
poly.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <algorithm>
4#include <array>
5#include <bit>
6#include <complex>
7#include <type_traits>
8#include <vector>
9
14
15namespace maomao90 {
16using namespace std;
17namespace internal::poly {
/**
 * Schoolbook convolution: returns c with c[k] = sum_{i + j = k} a[i] * b[j].
 *
 * Runs in O(|a| * |b|) time. The loop nesting is chosen so that the inner
 * loop runs over the longer operand.
 *
 * @param a first operand (coefficient i multiplies x^i)
 * @param b second operand
 * @return vector of size |a| + |b| - 1, or an empty vector if either operand
 *         is empty (matching every caller in this file).
 */
template <class T>
std::vector<T> convolution_naive(const std::vector<T> &a,
                                 const std::vector<T> &b) {
  const int n = static_cast<int>(a.size());
  const int m = static_cast<int>(b.size());
  // Without this guard, n + m - 1 is -1 for two empty inputs (a huge size_t
  // allocation request). For one empty input it would yield zeros instead of
  // an empty product.
  if (n == 0 || m == 0) {
    return {};
  }
  std::vector<T> ans(n + m - 1);
  if (n < m) {
    for (int j = 0; j < m; j++) {
      for (int i = 0; i < n; i++) {
        ans[i + j] += a[i] * b[j];
      }
    }
  } else {
    for (int i = 0; i < n; i++) {
      for (int j = 0; j < m; j++) {
        ans[i + j] += a[i] * b[j];
      }
    }
  }
  return ans;
}
37namespace ntt {
// Precomputed roots of unity for the NTT over mint (used by butterfly and
// butterfly_inv through a function-local static instance).
// rank2 = ctz(mod - 1): 2^rank2 is the largest power of two dividing mod - 1,
// hence the largest supported transform length. g defaults to a primitive
// root of mint::imod().
// NOTE(review): the constructor header (Doxygen source line 51) is missing
// from this listing; the statements from "52" to "77 }" form its body.
38template <StaticModInt mint,
39 int g = internal::math::primitive_root<mint::imod()>>
40struct ntt_info {
41 static constexpr int rank2 = __builtin_ctz(mint::umod() - 1);
42 array<mint, rank2 + 1> root; // root[i]^(2^i) == 1
43 array<mint, rank2 + 1> iroot; // root[i] * iroot[i] == 1
44
// rate2[i] = root[i + 2] * prod_{k < i} iroot[k + 2] (irate2: the inverse
// counterpart). The radix-2 stages multiply their running twiddle by
// rate2[countr_zero(~s)] when moving from block s to block s + 1.
45 array<mint, max(0, rank2 - 2 + 1)> rate2;
46 array<mint, max(0, rank2 - 2 + 1)> irate2;
47
// Same construction with root[i + 3], used by the radix-4 stages.
48 array<mint, max(0, rank2 - 3 + 1)> rate3;
49 array<mint, max(0, rank2 - 3 + 1)> irate3;
50
// root[rank2] = g^((mod - 1) / 2^rank2); each lower level is the square of
// the level above, so root[i] has order dividing 2^i.
52 root[rank2] = mint(g).pow((mint::imod() - 1) >> rank2);
53 iroot[rank2] = root[rank2].inv();
54 for (int i = rank2 - 1; i >= 0; i--) {
55 root[i] = root[i + 1] * root[i + 1];
56 iroot[i] = iroot[i + 1] * iroot[i + 1];
57 }
58
59 {
60 mint prod = 1, iprod = 1;
61 for (int i = 0; i <= rank2 - 2; i++) {
62 rate2[i] = root[i + 2] * prod;
63 irate2[i] = iroot[i + 2] * iprod;
64 prod *= iroot[i + 2];
65 iprod *= root[i + 2];
66 }
67 }
68 {
69 mint prod = 1, iprod = 1;
70 for (int i = 0; i <= rank2 - 3; i++) {
71 rate3[i] = root[i + 3] * prod;
72 irate3[i] = iroot[i + 3] * iprod;
73 prod *= iroot[i + 3];
74 iprod *= root[i + 3];
75 }
76 }
77 }
78};
79
/**
 * In-place forward NTT of a over mint, using radix-4 stages plus at most one
 * radix-2 stage.
 *
 * a.size() is treated as 2^h with h = ctz(a.size()). It is therefore expected
 * to be a power of two no larger than 2^rank2; convolution_ntt resizes to
 * bit_ceil(...) before calling.
 * No bit-reversal permutation is applied. convolution_ntt only multiplies
 * transformed vectors pointwise and then calls butterfly_inv, so the output
 * order does not matter there (presumably bit-reversed -- TODO confirm).
 */
80template <StaticModInt mint> void butterfly(vector<mint> &a) {
81 int n = a.size();
82 int h = __builtin_ctz((unsigned int)n);
83
// Root tables are built once per mint type (function-local static).
84 static const ntt_info<mint> info;
85
86 int len = 0; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
87 while (len < h) {
88 if (h - len == 1) {
// Exactly one level left: radix-2 step, twiddle `rot` for block s.
89 int p = 1 << (h - len - 1);
90 mint rot = 1;
91 for (int s = 0; s < (1 << len); s++) {
92 int offset = s << (h - len);
93 for (int i = 0; i < p; i++) {
94 auto l = a[i + offset];
95 auto r = a[i + offset + p] * rot;
96 a[i + offset] = l + r;
97 a[i + offset + p] = l - r;
98 }
// Advance the twiddle to block s + 1 via the precomputed ratio.
99 if (s + 1 != (1 << len)) {
100 rot *= info.rate2[countr_zero(~(unsigned int)(s))];
101 }
102 }
103 len++;
104 } else {
105 // 4-base
// imag = root[2], a primitive 4th root of unity.
// Terms are kept unreduced in unsigned long long (offsets are multiples of
// mod2 = mod^2 to stay non-negative) and reduced once when assigned back to
// mint. This relies on a few multiples of mod^2 fitting in 64 bits.
106 int p = 1 << (h - len - 2);
107 mint rot = 1, imag = info.root[2];
108 for (int s = 0; s < (1 << len); s++) {
109 mint rot2 = rot * rot;
110 mint rot3 = rot2 * rot;
111 int offset = s << (h - len);
112 for (int i = 0; i < p; i++) {
113 auto mod2 = 1ull * mint::imod() * mint::imod();
114 auto a0 = 1ull * a[i + offset].val();
115 auto a1 = 1ull * a[i + offset + p].val() * rot.val();
116 auto a2 = 1ull * a[i + offset + 2 * p].val() * rot2.val();
117 auto a3 = 1ull * a[i + offset + 3 * p].val() * rot3.val();
118 auto a1na3imag = 1ull * mint(a1 + mod2 - a3).val() * imag.val();
119 auto na2 = mod2 - a2;
120 a[i + offset] = a0 + a2 + a1 + a3;
121 a[i + offset + 1 * p] = a0 + a2 + (2 * mod2 - (a1 + a3));
122 a[i + offset + 2 * p] = a0 + na2 + a1na3imag;
123 a[i + offset + 3 * p] = a0 + na2 + (mod2 - a1na3imag);
124 }
125 if (s + 1 != (1 << len)) {
126 rot *= info.rate3[countr_zero(~(unsigned int)(s))];
127 }
128 }
129 len += 2;
130 }
131 }
132}
133
/**
 * In-place inverse of butterfly: undoes the stages in reverse order, using
 * the inverse roots (iroot / irate2 / irate3).
 *
 * The result is NOT divided by a.size(); convolution_ntt multiplies by
 * mint(z).inv() afterwards. a.size() is expected to be a power of two
 * (h = ctz(a.size())).
 */
134template <StaticModInt mint> void butterfly_inv(vector<mint> &a) {
135 int n = a.size();
136 int h = __builtin_ctz((unsigned int)n);
137
// Shares the per-mint static root tables with butterfly's own instance type.
138 static const ntt_info<mint> info;
139
140 int len = h; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
141 while (len) {
142 if (len == 1) {
// Last remaining level: radix-2 inverse step with twiddle irot.
// mod + l - r keeps the difference non-negative before the 64-bit multiply.
143 int p = 1 << (h - len);
144 mint irot = 1;
145 for (int s = 0; s < (1 << (len - 1)); s++) {
146 int offset = s << (h - len + 1);
147 for (int i = 0; i < p; i++) {
148 auto l = a[i + offset];
149 auto r = a[i + offset + p];
150 a[i + offset] = l + r;
151 a[i + offset + p] =
152 (unsigned long long)(mint::imod() + l.val() - r.val()) *
153 irot.val();
154 ;
155 }
156 if (s + 1 != (1 << (len - 1))) {
157 irot *= info.irate2[countr_zero(~(unsigned int)(s))];
158 }
159 }
160 len--;
161 } else {
162 // 4-base
// iimag = iroot[2]. Sums are kept unreduced in unsigned long long (adding
// mod before each subtraction keeps them non-negative) and reduced when
// assigned back to mint.
163 int p = 1 << (h - len);
164 mint irot = 1, iimag = info.iroot[2];
165 for (int s = 0; s < (1 << (len - 2)); s++) {
166 mint irot2 = irot * irot;
167 mint irot3 = irot2 * irot;
168 int offset = s << (h - len + 2);
169 for (int i = 0; i < p; i++) {
170 auto a0 = 1ull * a[i + offset + 0 * p].val();
171 auto a1 = 1ull * a[i + offset + 1 * p].val();
172 auto a2 = 1ull * a[i + offset + 2 * p].val();
173 auto a3 = 1ull * a[i + offset + 3 * p].val();
174
175 auto a2na3iimag =
176 1ull * mint((mint::imod() + a2 - a3) * iimag.val()).val();
177
178 a[i + offset] = a0 + a1 + a2 + a3;
179 a[i + offset + 1 * p] =
180 (a0 + (mint::imod() - a1) + a2na3iimag) * irot.val();
181 a[i + offset + 2 * p] =
182 (a0 + a1 + (mint::imod() - a2) + (mint::imod() - a3)) *
183 irot2.val();
184 a[i + offset + 3 * p] =
185 (a0 + (mint::imod() - a1) + (mint::imod() - a2na3iimag)) *
186 irot3.val();
187 }
188 if (s + 1 != (1 << (len - 2))) {
189 irot *= info.irate3[countr_zero(~(unsigned int)(s))];
190 }
191 }
192 len -= 2;
193 }
194 }
195}
196template <StaticModInt mint>
197vector<mint> convolution_ntt(vector<mint> a, vector<mint> b) {
198 int n = a.size(), m = b.size();
199 int z = bit_ceil((unsigned int)(n + m - 1));
200 a.resize(z);
201 butterfly(a);
202 b.resize(z);
203 butterfly(b);
204 for (int i = 0; i < z; i++) {
205 a[i] *= b[i];
206 }
207 butterfly_inv(a);
208 a.resize(n + m - 1);
209 mint iz = mint(z).inv();
210 for (int i = 0; i < n + m - 1; i++)
211 a[i] *= iz;
212 return a;
213}
214template <StaticModInt mint>
215vector<mint> convolution(const vector<mint> &a, const vector<mint> &b) {
216 int n = a.size(), m = b.size();
217 if (!n || !m) {
218 return {};
219 }
220
221 int z = bit_ceil((unsigned int)(n + m - 1));
222 assert((mint::imod() - 1) % z == 0);
223
224 if (min(n, m) <= 60) {
225 return convolution_naive(a, b);
226 }
227 return convolution_ntt(a, b);
228}
229template <int mod, class T>
230 requires is_integral_v<T>
231vector<T> convolution(const vector<T> &a, const vector<T> &b) {
232 int n = a.size(), m = b.size();
233 if (!n || !m)
234 return {};
235
236 using mint = static_modint<mod>;
237
238 int z = bit_ceil((unsigned int)(n + m - 1));
239 assert((mint::imod() - 1) % z == 0);
240
241 vector<mint> a2(n), b2(m);
242 for (int i = 0; i < n; i++) {
243 a2[i] = mint(a[i]);
244 }
245 for (int i = 0; i < m; i++) {
246 b2[i] = mint(b[i]);
247 }
248 auto c2 = convolution(std::move(a2), std::move(b2));
249 vector<T> c(n + m - 1);
250 for (int i = 0; i < n + m - 1; i++) {
251 c[i] = c2[i].val();
252 }
253 return c;
254}
/**
 * Exact convolution of long long sequences.
 *
 * The product is computed modulo three NTT primes (MOD1, MOD2, MOD3). Each
 * coefficient is rebuilt by CRT in unsigned 64-bit arithmetic; the correction
 * step below assumes every true coefficient r satisfies -2^63 <= r < 2^63.
 * Inputs with min(|a|, |b|) <= 60 use convolution_naive.
 *
 * @pre |a| + |b| - 1 <= 2^24 (asserted).
 * @return vector of size |a| + |b| - 1, or empty if either input is empty.
 */
255inline vector<long long> convolution_ll(const vector<long long> &a,
256 const vector<long long> &b) {
257 int n = a.size(), m = b.size();
258 if (!n || !m) {
259 return {};
260 }
261
262 if (min(n, m) <= 60) {
263 return convolution_naive(a, b);
264 }
265
// The "2^k" notes mean 2^k divides MODk - 1, i.e. the prime supports NTT
// lengths up to 2^k (the static_asserts below check 2^24 for all three).
266 static constexpr unsigned long long MOD1 = 754974721; // 2^24
267 static constexpr unsigned long long MOD2 = 167772161; // 2^25
268 static constexpr unsigned long long MOD3 = 469762049; // 2^26
// Pairwise and full products; M1M2M3 wraps modulo 2^64 (the M' in the
// derivation inside the loop).
269 static constexpr unsigned long long M2M3 = MOD2 * MOD3;
270 static constexpr unsigned long long M1M3 = MOD1 * MOD3;
271 static constexpr unsigned long long M1M2 = MOD1 * MOD2;
272 static constexpr unsigned long long M1M2M3 = MOD1 * MOD2 * MOD3;
273
// CRT coefficients, presumably (product of the other two moduli)^{-1} modulo
// the given modulus, as returned by inv_gcd(x, mod).
274 static constexpr unsigned long long i1 =
275 inv_gcd<long long>(MOD2 * MOD3, MOD1);
276 static constexpr unsigned long long i2 =
277 inv_gcd<long long>(MOD1 * MOD3, MOD2);
278 static constexpr unsigned long long i3 =
279 inv_gcd<long long>(MOD1 * MOD2, MOD3);
280
281 static constexpr int MAX_AB_BIT = 24;
282 static_assert(MOD1 % (1ull << MAX_AB_BIT) == 1,
283 "MOD1 isn't enough to support an array length of 2^24.");
284 static_assert(MOD2 % (1ull << MAX_AB_BIT) == 1,
285 "MOD2 isn't enough to support an array length of 2^24.");
286 static_assert(MOD3 % (1ull << MAX_AB_BIT) == 1,
287 "MOD3 isn't enough to support an array length of 2^24.");
288 assert(n + m - 1 <= (1 << MAX_AB_BIT));
289
// Residues of the exact product modulo each prime.
290 vector<long long> c1 = convolution<MOD1>(a, b);
291 vector<long long> c2 = convolution<MOD2>(a, b);
292 vector<long long> c3 = convolution<MOD3>(a, b);
293
294 vector<long long> c(n + m - 1);
295 for (int i = 0; i < n + m - 1; i++) {
// x = CRT sum modulo 2^64; it may differ from r by 0..3 copies of M1M2M3.
296 unsigned long long x = 0;
297 x += (c1[i] * i1) % MOD1 * M2M3;
298 x += (c2[i] * i2) % MOD2 * M1M3;
299 x += (c3[i] * i3) % MOD3 * M1M2;
300 // B = 2^63, -B <= x, r(real value) < B
301 // (x, x - M, x - 2M, or x - 3M) = r (mod 2B)
302 // r = c1[i] (mod MOD1)
303 // focus on MOD1
304 // r = x, x - M', x - 2M', x - 3M' (M' = M % 2^64) (mod 2B)
305 // r = x,
306 // x - M' + (0 or 2B),
307 // x - 2M' + (0, 2B or 4B),
308 // x - 3M' + (0, 2B, 4B or 6B) (without mod!)
309 // (r - x) = 0, (0)
310 // - M' + (0 or 2B), (1)
311 // -2M' + (0 or 2B or 4B), (2)
312 // -3M' + (0 or 2B or 4B or 6B) (3) (mod MOD1)
313 // we checked that
314 // ((1) mod MOD1) mod 5 = 2
315 // ((2) mod MOD1) mod 5 = 3
316 // ((3) mod MOD1) mod 5 = 4
// _x = (signed x) mod MOD1 in [0, MOD1); diff = (r - x) mod MOD1 selects
// how many copies of M1M2M3 to subtract (via diff % 5, see above).
317 long long _x = (long long)x % (long long)MOD1;
318 if (_x < 0) {
319 _x += MOD1;
320 }
321 long long diff = c1[i] - _x;
322 if (diff < 0)
323 diff += MOD1;
324 static constexpr unsigned long long offset[5] = {0, 0, M1M2M3, 2 * M1M2M3,
325 3 * M1M2M3};
326 x -= offset[diff % 5];
327 c[i] = x;
328 }
329
330 return c;
331}
332template <ModInt mint>
333inline vector<mint> convolution_arb_mod(const vector<mint> &a,
334 const vector<mint> &b) {
335 int n = a.size(), m = b.size();
336 if (!n || !m) {
337 return {};
338 }
339
340 if (min(n, m) <= 60) {
341 return convolution_naive(a, b);
342 }
343
344 static constexpr long long MOD1 = 167772161; // 2^25
345 static constexpr long long MOD2 = 469762049; // 2^26
346 static constexpr long long MOD3 = 754974721; // 2^24
347
348 static constexpr long long INV12 = inv_gcd(MOD1, MOD2);
349 static constexpr long long INV13 = inv_gcd(MOD1, MOD3);
350 static constexpr long long INV23 = inv_gcd(MOD2, MOD3);
351 static constexpr long long INV13INV23 = INV13 * INV23 % MOD3;
352 static constexpr long long W1 = MOD1 % mint::imod();
353 static constexpr long long W2 = W1 * MOD2 % mint::imod();
354
355 static constexpr int MAX_AB_BIT = 24;
356 static_assert(MOD1 % (1ull << MAX_AB_BIT) == 1,
357 "MOD1 isn't enough to support an array length of 2^24.");
358 static_assert(MOD2 % (1ull << MAX_AB_BIT) == 1,
359 "MOD2 isn't enough to support an array length of 2^24.");
360 static_assert(MOD3 % (1ull << MAX_AB_BIT) == 1,
361 "MOD3 isn't enough to support an array length of 2^24.");
362 assert(n + m - 1 <= (1 << MAX_AB_BIT));
363
364 vector<long long> _a(n), _b(m);
365 for (int i = 0; i < n; i++) {
366 _a[i] = a[i].val();
367 }
368 for (int i = 0; i < m; i++) {
369 _b[i] = b[i].val();
370 }
371
372 vector<long long> c1 = convolution<MOD1>(_a, _b);
373 vector<long long> c2 = convolution<MOD2>(_a, _b);
374 vector<long long> c3 = convolution<MOD3>(_a, _b);
375
376 vector<mint> c(n + m - 1);
377 for (int i = 0; i < n + m - 1; i++) {
378 long long x = (c2[i] + MOD2 - c1[i]) * INV12 % MOD2;
379 long long y =
380 ((c3[i] + MOD3 - c1[i]) * INV13INV23 + (MOD3 - x) * INV23) % MOD3;
381 c[i] = c1[i] + x * W1 + y * W2;
382 }
383
384 return c;
385}
386} // namespace ntt
387namespace fft {
/**
 * In-place iterative radix-2 complex FFT of a. The result is unnormalised;
 * callers divide by the length themselves.
 *
 * a.size() must be a power of two; L = log2(n) comes from __builtin_clz,
 * which is undefined for n == 0. The twiddle tables R (long double, for
 * accuracy) and rt (T) are function-local statics, together with the loop
 * counter k. They only ever grow and satisfy rt[k + j] = exp(i*pi*j/k) for
 * 0 <= j < k. They are shared by all calls with the same T.
 * NOTE(review): unsynchronised statics -- not thread-safe.
 * Callers obtain inverse transforms by index reversal (a[-i]) before a
 * second forward call.
 */
388template <typename T>
389 requires is_floating_point_v<T>
390inline void fft(vector<complex<T>> &a) {
391 int n = a.size(), L = 31 - __builtin_clz(n);
392 static vector<complex<long double>> R(2, 1);
393 static vector<complex<T>> rt(2, 1); // (^ 10% faster if double)
// Extend the root tables up to size n; `static int k` remembers how far
// previous calls already got.
394 for (static int k = 2; k < n; k *= 2) {
395 R.resize(n);
396 rt.resize(n);
397 auto x = polar(1.0L, acos(-1.0L) / k);
398 for (int i = k; i < 2 * k; i++) {
399 rt[i] = R[i] = i & 1 ? R[i / 2] * x : R[i / 2];
400 }
401 }
// Bit-reversal permutation.
402 vector<int> rev(n);
403 for (int i = 0; i < n; i++) {
404 rev[i] = (rev[i / 2] | (i & 1) << L) / 2;
405 }
406 for (int i = 0; i < n; i++) {
407 if (i < rev[i]) {
408 swap(a[i], a[rev[i]]);
409 }
410 }
// Butterflies of half-size k.
411 for (int k = 1; k < n; k *= 2) {
412 for (int i = 0; i < n; i += 2 * k) {
413 for (int j = 0; j < k; j++) {
414 // complex<T> z = rt[j+k] * a[i+j+k]; // (25% faster if hand-rolled)
// Reads the real/imag parts through T* (complex<T> is array-compatible with
// T[2]).
415 auto x = (T *)&rt[j + k], y = (T *)&a[i + j + k];
416 complex<T> z(x[0] * y[0] - x[1] * y[1], x[0] * y[1] + x[1] * y[0]);
417 a[i + j + k] = a[i + j] - z;
418 a[i + j] += z;
419 }
420 }
421 }
422}
423template <typename T>
424 requires is_arithmetic_v<T>
425inline vector<T> convolution(const vector<T> &ta, const vector<T> &tb) {
426 int n = ta.size(), m = tb.size();
427 if (!n || !m) {
428 return {};
429 }
430 if (min(n, m) <= 60) {
431 return convolution_naive(ta, tb);
432 }
433 vector<double> a(n), b(m);
434 for (int i = 0; i < n; i++) {
435 a[i] = ta[i];
436 }
437 for (int i = 0; i < m; i++) {
438 b[i] = tb[i];
439 }
440 int z = bit_ceil((unsigned int)(n + m - 1));
441 vector<complex<double>> in(z), out(z);
442 copy(a.begin(), a.end(), in.begin());
443 for (int i = 0; i < m; i++) {
444 in[i].imag(b[i]);
445 }
446 fft(in);
447 for (complex<double> &x : in) {
448 x *= x;
449 }
450 for (int i = 0; i < z; i++) {
451 out[i] = in[-i & (z - 1)] - conj(in[i]);
452 }
453 fft(out);
454 vector<T> res(n + m - 1);
455 for (int i = 0; i < n + m - 1; i++) {
456 if constexpr (integral<T>) {
457 res[i] = imag(out[i]) / (4 * z) + 0.5;
458 } else {
459 res[i] = imag(out[i]) / (4 * z);
460 }
461 }
462 return res;
463}
464template <concepts::broadly_integral T>
465inline vector<T> convolution_sqrt(const vector<T> &a, const vector<T> &b) {
466 int n = a.size(), m = b.size();
467 if (!n || !m) {
468 return {};
469 }
470 if (min(n, m) <= 60) {
471 return convolution_naive(a, b);
472 }
473 int z = bit_ceil((unsigned int)(n + m - 1)), cut = 1 << 15;
474 vector<complex<double>> L(z), R(z), outs(z), outl(z);
475 for (int i = 0; i < n; i++) {
476 L[i] = complex<double>(a[i] >> 15, a[i] & ((1 << 15) - 1));
477 }
478 for (int i = 0; i < m; i++) {
479 R[i] = complex<double>(b[i] >> 15, b[i] & ((1 << 15) - 1));
480 }
481 fft(L), fft(R);
482 for (int i = 0; i < z; i++) {
483 int j = -i & (z - 1);
484 outl[j] = (L[i] + conj(L[j])) * R[i] / (2.0 * z);
485 outs[j] = (L[i] - conj(L[j])) * R[i] / (2.0 * z) / 1i;
486 }
487 fft(outl), fft(outs);
488 vector<T> res(n + m - 1);
489 for (int i = 0; i < n + m - 1; i++) {
490 T av = (T)(real(outl[i]) + .5), cv = (T)(imag(outs[i]) + .5);
491 T bv = (T)(imag(outl[i]) + .5) + (T)(real(outs[i]) + .5);
492 res[i] = (av * cut + bv) * cut + cv;
493 }
494 return res;
495}
496template <ModInt mint>
497vector<mint> convolution_arb_mod(const vector<mint> &a, const vector<mint> &b) {
498 int n = a.size(), m = b.size();
499 if (!n || !m) {
500 return {};
501 }
502 if (min(n, m) <= 60) {
503 return convolution_naive(a, b);
504 }
505 int z = bit_ceil((unsigned int)(n + m - 1)), cut = int(sqrt(mint::imod()));
506 vector<complex<double>> L(z), R(z), outs(z), outl(z);
507 for (int i = 0; i < n; i++) {
508 L[i] = complex<double>(a[i].val() / cut, a[i].val() % cut);
509 }
510 for (int i = 0; i < m; i++) {
511 R[i] = complex<double>(b[i].val() / cut, b[i].val() % cut);
512 }
513 fft(L), fft(R);
514 for (int i = 0; i < z; i++) {
515 int j = -i & (z - 1);
516 outl[j] = (L[i] + conj(L[j])) * R[i] / (2.0 * z);
517 outs[j] = (L[i] - conj(L[j])) * R[i] / (2.0 * z) / 1i;
518 }
519 fft(outl), fft(outs);
520 vector<mint> res(n + m - 1);
521 for (int i = 0; i < n + m - 1; i++) {
522 long long av = (long long)(real(outl[i]) + .5),
523 cv = (long long)(imag(outs[i]) + .5);
524 long long bv =
525 (long long)(imag(outl[i]) + .5) + (long long)(real(outs[i]) + .5);
526 res[i] = (av % mint::imod() * cut + bv) % mint::imod() * cut + cv;
527 }
528 return res;
529}
530template <typename T>
531 requires is_floating_point_v<T>
532vector<complex<T>> convolution_complex(const vector<complex<T>> &a,
533 const vector<complex<T>> &b) {
534 int n = a.size(), m = b.size();
535 if (!a || !b) {
536 return {};
537 }
538 if (min(n, m) <= 60) {
539 return convolution_naive(a, b);
540 }
541 int z = bit_ceil((unsigned int)(n + m - 1));
542 a.resize(z, 0);
543 b.resize(z, 0);
544 fft(a);
545 fft(b);
546 for (int i = 0; i < z; i++) {
547 a[i] *= b[i] / (T)z;
548 }
549 reverse(a.begin() + 1, a.end());
550 fft(a);
551 a.resize(n + m - 1);
552 return a;
553}
554} // namespace fft
555} // namespace internal::poly
// NOTE(review): the enum header and the `ntt` / `fft_complex` enumerators
// (Doxygen lines 556, 557 and 560) are missing from this listing.
558 fft, // faster than fft_sqrt but less precise
559 fft_sqrt, // splits each value into two smaller halves before the FFT for better precision
561};
562namespace internal::type_traits {
// Primary template: a (T, PolySetting) pair is rejected unless one of the
// partial specialisations below opts it in.
563template <typename T, PolySetting poly_setting>
564struct is_valid_setting : false_type {};
565// ntt allows ModInt types or integral types of at most 64 bits
566template <ModInt T> struct is_valid_setting<T, PolySetting::ntt> : true_type {};
567template <integral T>
568 requires internal::type_traits::is_64bit_or_less_v<T>
569struct is_valid_setting<T, PolySetting::ntt> : true_type {};
570// fft_sqrt allows broadly_integral types or ModInt types
571template <internal::concepts::broadly_integral T>
572struct is_valid_setting<T, PolySetting::fft_sqrt> : true_type {};
573template <ModInt T>
574struct is_valid_setting<T, PolySetting::fft_sqrt> : true_type {};
575// fft allows arithmetic types
576template <typename T>
577 requires is_arithmetic_v<T>
578struct is_valid_setting<T, PolySetting::fft> : true_type {};
579// fft_complex allows complex numbers
580template <typename T>
581 requires is_floating_point_v<T>
582struct is_valid_setting<complex<T>, PolySetting::fft_complex> : true_type {};
583
// NOTE(review): the definition of is_valid_setting_v that follows this
// template header (Doxygen line 585) is missing from this listing.
584template <typename T, PolySetting poly_setting>
586} // namespace internal::type_traits
587
/**
 * Polynomial with coefficients of type T stored in `v`. Presumably v[i] is
 * the coefficient of x^i; degree() is v.size() - 1.
 *
 * poly_setting selects the multiplication backend. Only (T, poly_setting)
 * pairs accepted by is_valid_setting_v can be instantiated (enable_if).
 *
 * NOTE(review): the actual convolution calls in operator*= (Doxygen lines
 * 606, 608, 626, 628, 631, 633) are missing from this listing.
 */
588template <
589 typename T, PolySetting poly_setting,
590 enable_if_t<internal::type_traits::is_valid_setting_v<T, poly_setting>,
591 nullptr_t> = nullptr>
592struct Poly {
// Default: the zero polynomial with a single coefficient.
593 constexpr Poly() : v(1, 0) {}
// n value-initialised coefficients. NOTE(review): not `explicit`, so an int
// converts to Poly implicitly.
594 constexpr Poly(int n) : v(n) {}
// NOTE(review): copies the by-value parameter; v(std::move(v)) would avoid it.
595 constexpr Poly(vector<T> v) : v(v) {}
596
// Returns -1 for an empty coefficient vector (e.g. Poly(0)).
597 constexpr int degree() const { return v.size() - 1; }
598 constexpr T operator[](int i) const { return v[i]; }
599 constexpr T &operator[](int i) { return v[i]; }
600
// In-place product with o, using the backend chosen by poly_setting.
601 constexpr Poly &operator*=(const Poly &o) {
602 if constexpr (poly_setting == PolySetting::ntt) {
603 if constexpr (ModInt<T>) {
// Plain NTT only when T is a static prime-modulus modint whose mod - 1 is
// divisible by the padded length; otherwise the arbitrary-mod path.
604 int z = bit_ceil((unsigned int)(degree() + o.degree() + 1));
605 if (StaticModInt<T> && T::is_prime_mod && (T::imod() - 1) % z == 0) {
607 } else {
609 }
610 } else { // integral
// Widen to long long and use the exact three-prime convolution.
611 vector<long long> a(degree() + 1), b(o.degree() + 1);
612 for (int i = 0; i <= degree(); i++) {
613 a[i] = v[i];
614 }
615 for (int i = 0; i <= o.degree(); i++) {
616 b[i] = o.v[i];
617 }
618 vector<long long> res = internal::poly::ntt::convolution_ll(a, b);
619 v.resize(res.size());
620 for (size_t i = 0; i < res.size(); i++) {
621 v[i] = res[i];
622 }
623 }
624 } else if constexpr (poly_setting == PolySetting::fft_sqrt) {
625 if constexpr (ModInt<T>) {
627 } else { // integral
629 }
630 } else if constexpr (poly_setting == PolySetting::fft) {
632 } else if constexpr (poly_setting == PolySetting::fft_complex) {
634 }
635 return *this;
636 }
// Out-of-place product: copies *this and multiplies the copy.
637 constexpr Poly operator*(const Poly &o) const {
638 Poly res = *this;
639 return res *= o;
640 }
641
642private:
643 vector<T> v;
644};
645} // namespace maomao90
Definition modint.hpp:21
Definition modint.hpp:23
constexpr int primitive_root
Definition math.hpp:61
void fft(vector< complex< T > > &a)
Definition poly.hpp:390
vector< complex< T > > convolution_complex(const vector< complex< T > > &a, const vector< complex< T > > &b)
Definition poly.hpp:532
vector< T > convolution_sqrt(const vector< T > &a, const vector< T > &b)
Definition poly.hpp:465
vector< mint > convolution_arb_mod(const vector< mint > &a, const vector< mint > &b)
Definition poly.hpp:497
vector< T > convolution(const vector< T > &ta, const vector< T > &tb)
Definition poly.hpp:425
Definition poly.hpp:37
vector< long long > convolution_ll(const vector< long long > &a, const vector< long long > &b)
Definition poly.hpp:255
vector< mint > convolution_arb_mod(const vector< mint > &a, const vector< mint > &b)
Definition poly.hpp:333
void butterfly(vector< mint > &a)
Definition poly.hpp:80
vector< mint > convolution_ntt(vector< mint > a, vector< mint > b)
Definition poly.hpp:197
void butterfly_inv(vector< mint > &a)
Definition poly.hpp:134
vector< mint > convolution(const vector< mint > &a, const vector< mint > &b)
Definition poly.hpp:215
Definition poly.hpp:17
vector< T > convolution_naive(const vector< T > &a, const vector< T > &b)
Definition poly.hpp:19
constexpr bool is_valid_setting_v
Definition poly.hpp:585
Definition hashmap.hpp:8
PolySetting
Definition poly.hpp:556
@ fft_complex
Definition poly.hpp:560
@ fft_sqrt
Definition poly.hpp:559
@ ntt
Definition poly.hpp:557
@ fft
Definition poly.hpp:558
constexpr T inv_gcd(T x, T mod)
Definition extended_gcd.hpp:12
constexpr Poly & operator*=(const Poly &o)
Definition poly.hpp:601
constexpr Poly()
Definition poly.hpp:593
constexpr T operator[](int i) const
Definition poly.hpp:598
constexpr Poly(int n)
Definition poly.hpp:594
constexpr T & operator[](int i)
Definition poly.hpp:599
constexpr Poly operator*(const Poly &o) const
Definition poly.hpp:637
constexpr Poly(vector< T > v)
Definition poly.hpp:595
constexpr int degree() const
Definition poly.hpp:597
array< mint, max(0, rank2 - 2+1)> irate2
Definition poly.hpp:46
static constexpr int rank2
Definition poly.hpp:41
array< mint, rank2+1 > root
Definition poly.hpp:42
array< mint, max(0, rank2 - 3+1)> rate3
Definition poly.hpp:48
array< mint, max(0, rank2 - 3+1)> irate3
Definition poly.hpp:49
array< mint, max(0, rank2 - 2+1)> rate2
Definition poly.hpp:45
array< mint, rank2+1 > iroot
Definition poly.hpp:43
Definition modint.hpp:28