20 int n = a.size(), m = b.size();
21 vector<T> ans(n + m - 1);
23 for (
int j = 0; j < m; j++) {
24 for (
int i = 0; i < n; i++) {
25 ans[i + j] += a[i] * b[j];
29 for (
int i = 0; i < n; i++) {
30 for (
int j = 0; j < m; j++) {
31 ans[i + j] += a[i] * b[j];
41 static constexpr int rank2 = __builtin_ctz(mint::umod() - 1);
42 array<mint, rank2 + 1>
root;
54 for (
int i =
rank2 - 1; i >= 0; i--) {
60 mint prod = 1, iprod = 1;
61 for (
int i = 0; i <=
rank2 - 2; i++) {
69 mint prod = 1, iprod = 1;
70 for (
int i = 0; i <=
rank2 - 3; i++) {
80template <StaticModInt m
int>
void butterfly(vector<mint> &a) {
82 int h = __builtin_ctz((
unsigned int)n);
89 int p = 1 << (h - len - 1);
91 for (
int s = 0; s < (1 << len); s++) {
92 int offset = s << (h - len);
93 for (
int i = 0; i < p; i++) {
94 auto l = a[i + offset];
95 auto r = a[i + offset + p] * rot;
96 a[i + offset] = l + r;
97 a[i + offset + p] = l - r;
99 if (s + 1 != (1 << len)) {
100 rot *= info.
rate2[countr_zero(~(
unsigned int)(s))];
106 int p = 1 << (h - len - 2);
107 mint rot = 1, imag = info.
root[2];
108 for (
int s = 0; s < (1 << len); s++) {
109 mint rot2 = rot * rot;
110 mint rot3 = rot2 * rot;
111 int offset = s << (h - len);
112 for (
int i = 0; i < p; i++) {
113 auto mod2 = 1ull * mint::imod() * mint::imod();
114 auto a0 = 1ull * a[i + offset].val();
115 auto a1 = 1ull * a[i + offset + p].val() * rot.val();
116 auto a2 = 1ull * a[i + offset + 2 * p].val() * rot2.val();
117 auto a3 = 1ull * a[i + offset + 3 * p].val() * rot3.val();
118 auto a1na3imag = 1ull * mint(a1 + mod2 - a3).val() * imag.val();
119 auto na2 = mod2 - a2;
120 a[i + offset] = a0 + a2 + a1 + a3;
121 a[i + offset + 1 * p] = a0 + a2 + (2 * mod2 - (a1 + a3));
122 a[i + offset + 2 * p] = a0 + na2 + a1na3imag;
123 a[i + offset + 3 * p] = a0 + na2 + (mod2 - a1na3imag);
125 if (s + 1 != (1 << len)) {
126 rot *= info.
rate3[countr_zero(~(
unsigned int)(s))];
136 int h = __builtin_ctz((
unsigned int)n);
143 int p = 1 << (h - len);
145 for (
int s = 0; s < (1 << (len - 1)); s++) {
146 int offset = s << (h - len + 1);
147 for (
int i = 0; i < p; i++) {
148 auto l = a[i + offset];
149 auto r = a[i + offset + p];
150 a[i + offset] = l + r;
152 (
unsigned long long)(mint::imod() + l.val() - r.val()) *
156 if (s + 1 != (1 << (len - 1))) {
157 irot *= info.
irate2[countr_zero(~(
unsigned int)(s))];
163 int p = 1 << (h - len);
164 mint irot = 1, iimag = info.
iroot[2];
165 for (
int s = 0; s < (1 << (len - 2)); s++) {
166 mint irot2 = irot * irot;
167 mint irot3 = irot2 * irot;
168 int offset = s << (h - len + 2);
169 for (
int i = 0; i < p; i++) {
170 auto a0 = 1ull * a[i + offset + 0 * p].val();
171 auto a1 = 1ull * a[i + offset + 1 * p].val();
172 auto a2 = 1ull * a[i + offset + 2 * p].val();
173 auto a3 = 1ull * a[i + offset + 3 * p].val();
176 1ull * mint((mint::imod() + a2 - a3) * iimag.val()).val();
178 a[i + offset] = a0 + a1 + a2 + a3;
179 a[i + offset + 1 * p] =
180 (a0 + (mint::imod() - a1) + a2na3iimag) * irot.val();
181 a[i + offset + 2 * p] =
182 (a0 + a1 + (mint::imod() - a2) + (mint::imod() - a3)) *
184 a[i + offset + 3 * p] =
185 (a0 + (mint::imod() - a1) + (mint::imod() - a2na3iimag)) *
188 if (s + 1 != (1 << (len - 2))) {
189 irot *= info.
irate3[countr_zero(~(
unsigned int)(s))];
196template <StaticModInt m
int>
198 int n = a.size(), m = b.size();
199 int z = bit_ceil((
unsigned int)(n + m - 1));
204 for (
int i = 0; i < z; i++) {
209 mint iz = mint(z).inv();
210 for (
int i = 0; i < n + m - 1; i++)
214template <StaticModInt m
int>
215vector<mint>
convolution(
const vector<mint> &a,
const vector<mint> &b) {
216 int n = a.size(), m = b.size();
221 int z = bit_ceil((
unsigned int)(n + m - 1));
222 assert((mint::imod() - 1) % z == 0);
224 if (min(n, m) <= 60) {
229template <
int mod,
class T>
230 requires is_integral_v<T>
232 int n = a.size(), m = b.size();
238 int z = bit_ceil((
unsigned int)(n + m - 1));
239 assert((mint::imod() - 1) % z == 0);
241 vector<mint> a2(n), b2(m);
242 for (
int i = 0; i < n; i++) {
245 for (
int i = 0; i < m; i++) {
248 auto c2 =
convolution(std::move(a2), std::move(b2));
249 vector<T> c(n + m - 1);
250 for (
int i = 0; i < n + m - 1; i++) {
256 const vector<long long> &b) {
257 int n = a.size(), m = b.size();
262 if (min(n, m) <= 60) {
266 static constexpr unsigned long long MOD1 = 754974721;
267 static constexpr unsigned long long MOD2 = 167772161;
268 static constexpr unsigned long long MOD3 = 469762049;
269 static constexpr unsigned long long M2M3 = MOD2 * MOD3;
270 static constexpr unsigned long long M1M3 = MOD1 * MOD3;
271 static constexpr unsigned long long M1M2 = MOD1 * MOD2;
272 static constexpr unsigned long long M1M2M3 = MOD1 * MOD2 * MOD3;
274 static constexpr unsigned long long i1 =
276 static constexpr unsigned long long i2 =
278 static constexpr unsigned long long i3 =
281 static constexpr int MAX_AB_BIT = 24;
282 static_assert(MOD1 % (1ull << MAX_AB_BIT) == 1,
283 "MOD1 isn't enough to support an array length of 2^24.");
284 static_assert(MOD2 % (1ull << MAX_AB_BIT) == 1,
285 "MOD2 isn't enough to support an array length of 2^24.");
286 static_assert(MOD3 % (1ull << MAX_AB_BIT) == 1,
287 "MOD3 isn't enough to support an array length of 2^24.");
288 assert(n + m - 1 <= (1 << MAX_AB_BIT));
294 vector<long long> c(n + m - 1);
295 for (
int i = 0; i < n + m - 1; i++) {
296 unsigned long long x = 0;
297 x += (c1[i] * i1) % MOD1 * M2M3;
298 x += (c2[i] * i2) % MOD2 * M1M3;
299 x += (c3[i] * i3) % MOD3 * M1M2;
317 long long _x = (
long long)x % (
long long)MOD1;
321 long long diff = c1[i] - _x;
324 static constexpr unsigned long long offset[5] = {0, 0, M1M2M3, 2 * M1M2M3,
326 x -= offset[diff % 5];
332template <ModInt m
int>
334 const vector<mint> &b) {
335 int n = a.size(), m = b.size();
340 if (min(n, m) <= 60) {
344 static constexpr long long MOD1 = 167772161;
345 static constexpr long long MOD2 = 469762049;
346 static constexpr long long MOD3 = 754974721;
348 static constexpr long long INV12 =
inv_gcd(MOD1, MOD2);
349 static constexpr long long INV13 =
inv_gcd(MOD1, MOD3);
350 static constexpr long long INV23 =
inv_gcd(MOD2, MOD3);
351 static constexpr long long INV13INV23 = INV13 * INV23 % MOD3;
352 static constexpr long long W1 = MOD1 % mint::imod();
353 static constexpr long long W2 = W1 * MOD2 % mint::imod();
355 static constexpr int MAX_AB_BIT = 24;
356 static_assert(MOD1 % (1ull << MAX_AB_BIT) == 1,
357 "MOD1 isn't enough to support an array length of 2^24.");
358 static_assert(MOD2 % (1ull << MAX_AB_BIT) == 1,
359 "MOD2 isn't enough to support an array length of 2^24.");
360 static_assert(MOD3 % (1ull << MAX_AB_BIT) == 1,
361 "MOD3 isn't enough to support an array length of 2^24.");
362 assert(n + m - 1 <= (1 << MAX_AB_BIT));
364 vector<long long> _a(n), _b(m);
365 for (
int i = 0; i < n; i++) {
368 for (
int i = 0; i < m; i++) {
376 vector<mint> c(n + m - 1);
377 for (
int i = 0; i < n + m - 1; i++) {
378 long long x = (c2[i] + MOD2 - c1[i]) * INV12 % MOD2;
380 ((c3[i] + MOD3 - c1[i]) * INV13INV23 + (MOD3 - x) * INV23) % MOD3;
381 c[i] = c1[i] + x * W1 + y * W2;
389 requires is_floating_point_v<T>
390inline void fft(vector<complex<T>> &a) {
391 int n = a.size(), L = 31 - __builtin_clz(n);
392 static vector<complex<long double>> R(2, 1);
393 static vector<complex<T>> rt(2, 1);
394 for (
static int k = 2; k < n; k *= 2) {
397 auto x = polar(1.0L, acos(-1.0L) / k);
398 for (
int i = k; i < 2 * k; i++) {
399 rt[i] = R[i] = i & 1 ? R[i / 2] * x : R[i / 2];
403 for (
int i = 0; i < n; i++) {
404 rev[i] = (rev[i / 2] | (i & 1) << L) / 2;
406 for (
int i = 0; i < n; i++) {
408 swap(a[i], a[rev[i]]);
411 for (
int k = 1; k < n; k *= 2) {
412 for (
int i = 0; i < n; i += 2 * k) {
413 for (
int j = 0; j < k; j++) {
415 auto x = (T *)&rt[j + k], y = (T *)&a[i + j + k];
416 complex<T> z(x[0] * y[0] - x[1] * y[1], x[0] * y[1] + x[1] * y[0]);
417 a[i + j + k] = a[i + j] - z;
424 requires is_arithmetic_v<T>
425inline vector<T>
convolution(
const vector<T> &ta,
const vector<T> &tb) {
426 int n = ta.size(), m = tb.size();
430 if (min(n, m) <= 60) {
433 vector<double> a(n), b(m);
434 for (
int i = 0; i < n; i++) {
437 for (
int i = 0; i < m; i++) {
440 int z = bit_ceil((
unsigned int)(n + m - 1));
441 vector<complex<double>> in(z), out(z);
442 copy(a.begin(), a.end(), in.begin());
443 for (
int i = 0; i < m; i++) {
447 for (complex<double> &x : in) {
450 for (
int i = 0; i < z; i++) {
451 out[i] = in[-i & (z - 1)] - conj(in[i]);
454 vector<T> res(n + m - 1);
455 for (
int i = 0; i < n + m - 1; i++) {
456 if constexpr (integral<T>) {
457 res[i] = imag(out[i]) / (4 * z) + 0.5;
459 res[i] = imag(out[i]) / (4 * z);
464template <concepts::broadly_
integral T>
466 int n = a.size(), m = b.size();
470 if (min(n, m) <= 60) {
473 int z = bit_ceil((
unsigned int)(n + m - 1)), cut = 1 << 15;
474 vector<complex<double>> L(z), R(z), outs(z), outl(z);
475 for (
int i = 0; i < n; i++) {
476 L[i] = complex<double>(a[i] >> 15, a[i] & ((1 << 15) - 1));
478 for (
int i = 0; i < m; i++) {
479 R[i] = complex<double>(b[i] >> 15, b[i] & ((1 << 15) - 1));
482 for (
int i = 0; i < z; i++) {
483 int j = -i & (z - 1);
484 outl[j] = (L[i] + conj(L[j])) * R[i] / (2.0 * z);
485 outs[j] = (L[i] - conj(L[j])) * R[i] / (2.0 * z) / 1i;
488 vector<T> res(n + m - 1);
489 for (
int i = 0; i < n + m - 1; i++) {
490 T av = (T)(real(outl[i]) + .5), cv = (T)(imag(outs[i]) + .5);
491 T bv = (T)(imag(outl[i]) + .5) + (T)(real(outs[i]) + .5);
492 res[i] = (av * cut + bv) * cut + cv;
496template <ModInt m
int>
498 int n = a.size(), m = b.size();
502 if (min(n, m) <= 60) {
505 int z = bit_ceil((
unsigned int)(n + m - 1)), cut = int(sqrt(mint::imod()));
506 vector<complex<double>> L(z), R(z), outs(z), outl(z);
507 for (
int i = 0; i < n; i++) {
508 L[i] = complex<double>(a[i].val() / cut, a[i].val() % cut);
510 for (
int i = 0; i < m; i++) {
511 R[i] = complex<double>(b[i].val() / cut, b[i].val() % cut);
514 for (
int i = 0; i < z; i++) {
515 int j = -i & (z - 1);
516 outl[j] = (L[i] + conj(L[j])) * R[i] / (2.0 * z);
517 outs[j] = (L[i] - conj(L[j])) * R[i] / (2.0 * z) / 1i;
520 vector<mint> res(n + m - 1);
521 for (
int i = 0; i < n + m - 1; i++) {
522 long long av = (
long long)(real(outl[i]) + .5),
523 cv = (
long long)(imag(outs[i]) + .5);
525 (
long long)(imag(outl[i]) + .5) + (
long long)(real(outs[i]) + .5);
526 res[i] = (av % mint::imod() * cut + bv) % mint::imod() * cut + cv;
531 requires is_floating_point_v<T>
533 const vector<complex<T>> &b) {
534 int n = a.size(), m = b.size();
538 if (min(n, m) <= 60) {
541 int z = bit_ceil((
unsigned int)(n + m - 1));
546 for (
int i = 0; i < z; i++) {
549 reverse(a.begin() + 1, a.end());
562namespace internal::type_traits {
563template <
typename T, PolySetting poly_setting>
568 requires internal::type_traits::is_64bit_or_less_v<T>
571template <
internal::concepts::broadly_
integral T>
577 requires is_arithmetic_v<T>
581 requires is_floating_point_v<T>
584template <
typename T, PolySetting poly_setting>
590 enable_if_t<internal::type_traits::is_valid_setting_v<T, poly_setting>,
591 nullptr_t> =
nullptr>
594 constexpr Poly(
int n) : v(n) {}
595 constexpr Poly(vector<T> v) : v(v) {}
597 constexpr int degree()
const {
return v.size() - 1; }
604 int z = bit_ceil((
unsigned int)(
degree() + o.
degree() + 1));
612 for (
int i = 0; i <=
degree(); i++) {
615 for (
int i = 0; i <= o.
degree(); i++) {
619 v.resize(res.size());
620 for (
size_t i = 0; i < res.size(); i++) {
constexpr int primitive_root
Definition math.hpp:61
void fft(vector< complex< T > > &a)
Definition poly.hpp:390
vector< complex< T > > convolution_complex(const vector< complex< T > > &a, const vector< complex< T > > &b)
Definition poly.hpp:532
vector< T > convolution_sqrt(const vector< T > &a, const vector< T > &b)
Definition poly.hpp:465
vector< mint > convolution_arb_mod(const vector< mint > &a, const vector< mint > &b)
Definition poly.hpp:497
vector< T > convolution(const vector< T > &ta, const vector< T > &tb)
Definition poly.hpp:425
vector< long long > convolution_ll(const vector< long long > &a, const vector< long long > &b)
Definition poly.hpp:255
vector< mint > convolution_arb_mod(const vector< mint > &a, const vector< mint > &b)
Definition poly.hpp:333
void butterfly(vector< mint > &a)
Definition poly.hpp:80
vector< mint > convolution_ntt(vector< mint > a, vector< mint > b)
Definition poly.hpp:197
void butterfly_inv(vector< mint > &a)
Definition poly.hpp:134
vector< mint > convolution(const vector< mint > &a, const vector< mint > &b)
Definition poly.hpp:215
vector< T > convolution_naive(const vector< T > &a, const vector< T > &b)
Definition poly.hpp:19
constexpr bool is_valid_setting_v
Definition poly.hpp:585
PolySetting
Definition poly.hpp:556
@ fft_complex
Definition poly.hpp:560
@ fft_sqrt
Definition poly.hpp:559
@ ntt
Definition poly.hpp:557
@ fft
Definition poly.hpp:558
constexpr T inv_gcd(T x, T mod)
Definition extended_gcd.hpp:12
constexpr Poly & operator*=(const Poly &o)
Definition poly.hpp:601
constexpr Poly()
Definition poly.hpp:593
constexpr T operator[](int i) const
Definition poly.hpp:598
constexpr Poly(int n)
Definition poly.hpp:594
constexpr T & operator[](int i)
Definition poly.hpp:599
constexpr Poly operator*(const Poly &o) const
Definition poly.hpp:637
constexpr Poly(vector< T > v)
Definition poly.hpp:595
constexpr int degree() const
Definition poly.hpp:597
array< mint, max(0, rank2 - 2+1)> irate2
Definition poly.hpp:46
ntt_info()
Definition poly.hpp:51
static constexpr int rank2
Definition poly.hpp:41
array< mint, rank2+1 > root
Definition poly.hpp:42
array< mint, max(0, rank2 - 3+1)> rate3
Definition poly.hpp:48
array< mint, max(0, rank2 - 3+1)> irate3
Definition poly.hpp:49
array< mint, max(0, rank2 - 2+1)> rate2
Definition poly.hpp:45
array< mint, rank2+1 > iroot
Definition poly.hpp:43