题解是 GPT 写的。

1. 规则抽象成集合递推

设模数为 MM

  • 第一天被感染的集合记为

    B{0,1,,M1}B \subseteq \{0,1,\dots,M-1\}
  • dd 天被感染的集合记为 SdS_d

题意说:某人 pp 在某天会被感染,当且仅当存在

  • aSd1a \in S_{d-1}(前一天感染过)

  • bBb \in B(第一天感染过)
    使得

a×bp(modM)a\times b \equiv p \pmod M

因此有递推:

Sd=Sd1  BS_d = S_{d-1}\ \otimes\ B

其中 \otimes 表示“集合乘法”:

$$A\otimes C = \{(x\cdot y)\bmod M \mid x\in A,\ y\in C\}$$

又因为 S1=BS_1 = B,所以

$$S_2 = B\otimes B,\quad S_3 = (B\otimes B)\otimes B = B^3,\ \dots$$

最终得到:

SK=BKS_K = B^K

这里 BKB^K 表示在模 MM 下取 KK 次集合乘法(允许重复选同一个元素,符合“可以重复感染”的含义)。


2. 关键:如何计算集合乘法 ACA\otimes C

用一个长度为 MM 的布尔数组表示集合(0/1):

  • A[i]=1 表示 iAi\in A

  • C[j]=1 表示 jCj\in C

D[(i*j)%M] = 1  当且仅当存在 i∈A, j∈C

朴素做法就是两层枚举 i,ji,j,复杂度 O(M2)O(M^2)。因为 M1500M\le 150015002=2.25×1061500^2=2.25\times 10^6,是可以接受的。


3. 由于 K 很大:对集合做快速幂

我们要算 BKB^K。这是“幂”的结构,所以可以用二进制快速幂(Exponentiation by Squaring)。

关键点是需要一个“乘法单位元”集合:

  • 在模乘法下单位元是数字 11

  • 所以单位集合是 {1}\{1\}

令:

  • R = {1}(答案集合,初始为单位元)

  • Base = B

KK 的二进制位从高到低(或从低到高)迭代:

  • 每一步先做:R = R ⊗ R(相当于指数翻倍/平方)

  • 如果当前位为 1:R = R ⊗ Base

最后得到的 R 就是 BKB^K,也就是第 KK 天会感染的所有编号。

你给的代码正是这个过程:

  • ans 保存当前集合 R

  • 先平方(两层循环用 ans[org]ans[org]

  • 如果该位在 k 里为 1,再乘一次 baza(即 B


4. 正确性简述

  • 递推部分:由题目定义直接推出 Sd=Sd1BS_d = S_{d-1}\otimes B,归纳可得 SK=BKS_K = B^K

  • 快速幂部分:集合乘法满足结合律(本质来自整数乘法 mod MM 的结合律),因此二进制快速幂完全成立;从单位元 {1}\{1\} 出发,按位累乘能得到精确的 BKB^K


5. 复杂度

  • 集合乘法一次:O(M2)O(M^2)

  • 快速幂需要约 log2K60\log_2 K \le 60 次平方 + 若干次乘底

  • 总复杂度:

    O(M2logK)O(M^2\log K)

    M=1500M=1500 时大约是 2.25×106×601.35×1082.25\times 10^6 \times 60 \approx 1.35\times 10^8 级别,C++ 里可行。

  • 空间:O(M)O(M)


6. 输出

最终集合里为 1 的下标按从小到大输出即可。若 N=0N=0,则 BB 为空,任何乘法都得空集,输出空行也符合题意。

#include <bits/stdc++.h>
using namespace std;
// Time complexity: O(max_mod * max_mod * log(k))
const int max_mod = 1500;
char baza[max_mod];
char ans[2][max_mod], org = 0;
int mod;
long long k;
void load() {
  int n, t;
  scanf("%lld%d%d", &k, &mod, &n);
  for (int j = 0; j < n; ++j)
    scanf("%d", &t) && (baza[t] = 1);
}
void solve() {
  ans[org][1] = 1;
  for (unsigned long long mask = (1ll<<60); mask; mask >>=1) {
    memset(ans[!org], 0, mod);
    for (int i = 0; i < mod; ++i)
      for (int j = 0; j < mod; ++j)
        if (ans[org][i] && ans[org][j])
          ans[!org][ (i*j) % mod ] = 1;
    org = !org;
    if (mask & k) {
      memset(ans[!org], 0, mod);
      for (int i = 0; i < mod; ++i)
        for (int j = 0; j < mod; ++j)
          if (ans[org][i] && baza[j])
            ans[!org][ (i*j) % mod ] = 1;
      org = !org;
    }
  }
  for (int j = 0; j < mod; ++j) {
    if (ans[org][j])
      printf("%d ", j);
  }
  printf("\n");
}

int main() {
  load();
  solve();
  return 0;
}