数论|牛客小白月赛25 J异或和之和——组合数+位运算

牛客小白月赛25 J异或和之和 第一篇博客回顾一下上周的牛客小白月赛。
先来看一下题目题目链接
数论|牛客小白月赛25 J异或和之和——组合数+位运算
文章图片

之前从来没有认识到逆元的重要性,直到被J题卡了快一个小时,赛后问了大佬才意识到自己犯了很严重的错误orz。
先对题目分析一下,如果用一般的方法求出n的所有三元组再进行异或运算,那是必T无疑的,因此我们想到利用位运算的性质,如果三个数连续进行异或,只有两种情况会出1,分别是1 1 1之间的异或和1 0 0 之间的异或,那么我们只需要对每一位进行求组合数算出分别有几种出1的可能,然后依次相加就可以很容易得出结论。最后计算一下时间复杂度,long long 类型最多64位,N * 64 * k的复杂度显然满足要求。
其实这题的组合数比较水,不用逆元也可以过 ,下面来看两种求解的方法吧。
1.不用逆元
直接看代码

#include using namespace std; typedef long long ll; const int N = 200010; const int mod = 1e9 + 7; int c[64], n; ll res; inline ll C1(int x) { if (x >= 3) return 1ll * x * (x - 1) * (x - 2) / 6; else return 0; } inline ll C2(int x) { if (x >= 2) return 1ll * x * (x - 1) / 2; else return 0; } int main() { ios::sync_with_stdio(false); cin >> n; for (int i = 1; i <= n; i++) { ll x; intb = 0; cin >> x; while (x) c[b++] += (x & 1), x >>= 1; } ll bs = 1; for (int i = 0; i < 64; i++) { res = (res + (1ll * C1(c[i]) % mod + 1ll * c[i] * C2(n - c[i]) % mod) * bs) % mod; bs = (bs << 1) % mod; } cout << res; }

2.利用逆元进行求解
利用费马小定理
已知a^(p-1) % p = 1 => a * a ^(p-2) % p = 1(p为质数)
因此a^(p-2)就是我们所求的a的逆元,因此可以用快速幂求。
接着我们来看C(n, m) = n!/(m! * (n-m)!)这个式子,我们需要处理每个数的阶乘,因此我们可以用fac[]数组预先处理一下。
然后将除法转换成n!* inv(n!)
设 m! 的逆元是 N ,那么N = m! ^ (p-2)
设(n - m)! 的逆元是 M ,那么M = (n-m)! ^ (p-2)
那么原式可以转换为n! * N * M % p;
【数论|牛客小白月赛25 J异或和之和——组合数+位运算】下面是ac代码
#include using namespace std; typedef long long ll; const int N = 200010; const ll mod = 1e9 + 7; ll fac[N]; int c[64]; ll n, m; ll quick_pow(ll a, ll b) { a = a % mod; ll res = 1, base = a; while (b) { if (b & 1) res = res * base % mod; base = base * base % mod; b >>= 1; } return res; }void get_fac() { fac[0] = 1; for (int i = 1; i <= 200000; i++) fac[i] = fac[i - 1] * i * 1ll % mod; }ll combine(int n, int m) { if (n < m) return 0; else return fac[n] * quick_pow(fac[m], mod - 2) % mod * quick_pow(fac[n - m], mod - 2) % mod; }int main() { get_fac(); cin >> n; for (int i = 1; i <= n; i++) { ll x; int k = 0; cin >> x; while (x) c[k++] += x & 1, x >>= 1; } ll bs = 1; ll res = 0; for (int i = 0; i < 64; i++) { res = (res + (combine(c[i], 3) % mod + (c[i] * combine(n - c[i], 2) % mod)) * bs) % mod; bs = (bs << 1) % mod; } cout << res; }

    推荐阅读