ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

FFT/NTT字符串模糊匹配

2021-07-30 20:32:10  阅读:263  来源: 互联网

标签:const int s2 s1 FFT NTT 字符串 include sum


因为FFT精度问题太离谱了,所以墙裂推荐用NTT
首先考虑精确匹配:https://www.acwing.com/problem/content/833/
假设我们有短串\(s1\)(长度为\(n\)),长串\(s2\)(长度为\(m\))
我们定义字符差

\[c(x,y) = s1(x) - s2(y) \]

若\(c(x,y) = 0\),表明\(s1\)的第\(x\)个字符与\(s2\)的第\(y\)个字符匹配,再定义

\[F(x) = \sum_{i = 0}^{n - 1}c(i,x-n+i+1) \]

为\(s2\)子串的字符差之和,这个子串长为\(n\)并且以下标\(x\)为结尾,若\(F(x) = 0\),则表明这个子串与\(s1\)完全匹配,但这样可能会将\(ab\)与\(ba\)算为完全匹配,因此我们考虑将\(F(x)\)换个表达式

\[F(x) = \sum_{i = 0}^{n - 1}[s1(i)-s2(x-n+i+1)]^{2} \]

这样若\(F(x) = 0\),则表明这个子串与之完全匹配,将其暴力拆解

\[F(x) =\sum_{i = 0}^{n - 1}s1(i)^2+\sum_{i = 0}^{n - 1}s2(x-n+i+1)^2-\sum_{i = 0}^{n - 1}2s1(i)s2(x-n+i+1) \]

其中\(\sum_{i = 0}^{n - 1}s1(i)^2\)和\(\sum_{i = 0}^{n - 1}s2(x-n+i+1)^2\)都可以用前缀和解决,关键是\(\sum_{i = 0}^{n - 1}2s1(i)s2(x-n+i+1)\)。我们将\(s1\)翻转得到\(s1'\),即\(s1'(n-i-1)=s1(i)\),于是

\[\sum_{i = 0}^{n - 1}2s1(i)s2(x-n+i+1)=\sum_{i = 0}^{n - 1}2s1'(n-i-1)s2(x-n+i+1)=2\sum_{i+j=x}^{}s1'(i)s2(j) \]

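可以发现这是一个卷积,能用NTT啦!记\(sum=\sum_{i = 0}^{n - 1}s1(i)^2\),\(S(x)=\sum_{j = 0}^{x}s2(j)^2\)为\(s2\)的平方前缀和(约定\(S(-1)=0\)),因此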

\[F(x) = sum + S(x) - S(x-n) - 2\sum_{i+j=x}^{}s1'(i)s2(j) \]

当\(F(x)=0\)时,表明完全匹配

AC代码:
不开O2优化会超时(TLE)

#include <unordered_map>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <string>
#include <stack>
#include <deque>
#include <queue>
#include <cmath>
#include <map>
#include <set>
// Contest-style shortcut: every std name is visible unqualified below.
using namespace std;
// Requests GCC optimisation level 2 for the functions defined after this line.
#pragma GCC optimize(2)
// PII, PDI, ull, INF, M, base, P, eps and PI are template boilerplate; none of
// them is referenced by the functions of this program (main declares its own
// local named P).
typedef pair<int, int> PII;
typedef pair<double, int> PDI;
//typedef __int128 int128;
typedef long long ll;
typedef unsigned long long ull;
const int INF = 0x3f3f3f3f;
// N is the capacity of every global array declared below.
const int N = 1e7 + 10, M = 4e7 + 10;
const int base = 1e9;
const int P = 131;
// Modulus for all modular arithmetic; NTT() uses 3 as the root generator.
const int mod = 998244353;
const double eps = 1e-9;
const double PI = acos(-1.0);
// n / m: lengths read by main for s1 / s2.
// tot / bit: transform length and its log2, both set by inif().
int n, m, tot, bit;
// Raw input strings: s1 is the pattern, s2 is the text.
char s1[N], s2[N];
// S: prefix sums of b[i]^2 (mod `mod`); a, b: numeric encodings of s1 / s2 that
// double as the NTT work buffers.
ll S[N], a[N], b[N];
// Bit-reversal permutation for the current transform length, filled by inif().
int R[N];
// Binary (square-and-multiply) exponentiation.
// Returns a^b reduced modulo the global `mod`. Expects b >= 0; `a` is not
// reduced before the first multiplication, so callers pass values below `mod`.
ll ksm(ll a, ll b)
{
    ll result = 1 % mod;
    for (; b != 0; b >>= 1)
    {
        // Fold the current power of a into the result when this bit is set.
        if ((b & 1) != 0)
            result = result * a % mod;
        a = a * a % mod;
    }
    return result;
}
// Prepares the globals used by NTT() for a product whose highest index is n.
//
// Sets `tot` to the smallest power of two strictly greater than n, `bit` to
// log2(tot), and fills R[0..tot) with the bit-reversal permutation of `bit`-bit
// indices.
void inif(int n)
{
    tot = 1, bit = 0;
    while (tot <= n)
        tot *= 2, ++bit;
    // Index 0 always maps to itself. Starting the recurrence at 1 keeps every
    // write inside [0, tot) and never evaluates `<< (bit - 1)` with bit == 0
    // (which is what happens when n <= 0 and tot stays 1).
    R[0] = 0;
    for (int i = 1; i < tot; ++i)
        R[i] = (R[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}
// In-place number-theoretic transform of f[0..total) modulo `mod`.
//
// f     - coefficient array; entries may be any ll value and are first reduced
//         into [0, mod).
// total - transform length; must be the power of two for which inif() filled
//         the bit-reversal table R, and must divide mod - 1.
// type  - 1 for the forward transform, -1 for the inverse transform (which also
//         divides every entry of f by total).
void NTT(ll f[], int total, int type)
{
    // Canonicalise the input once so that every butterfly below works on
    // residues in [0, mod); negative entries would otherwise propagate and
    // could overflow narrow temporaries.
    for (int i = 0; i < total; ++i)
        f[i] = (f[i] % mod + mod) % mod;
    for (int i = 0; i < total; ++i)
        if (i < R[i])
            swap(f[i], f[R[i]]);
    for (int len = 2; len <= total; len *= 2)
    {
        const int half = len / 2;
        // 332748118 is the inverse of 3 modulo 998244353.
        const ll w1 = ksm(type == 1 ? 3 : 332748118, (mod - 1) / len);
        for (int pos = 0; pos < total; pos += len)
        {
            ll w = 1;
            for (int i = pos; i < pos + half; ++i, w = w * w1 % mod)
            {
                const ll x = f[i];
                const ll y = w * f[i + half] % mod;
                f[i] = (x + y) % mod;
                f[i + half] = (x - y + mod) % mod;
            }
        }
    }
    if (type == -1)
    {
        // Scale the array that was actually transformed, over its full length,
        // instead of relying on the globals a, n and m.
        const ll inv = ksm(total, mod - 2);
        for (int i = 0; i < total; ++i)
            f[i] = f[i] * inv % mod;
    }
}

int main()
{
    scanf("%d%s%d%s", &n, &s1, &m, &s2);
    for (int i = 0; i < n; ++i)
        a[i] = s1[i] - 'a' + 1;
    for (int i = 0; i < m; ++i)
        b[i] = s2[i] - 'a' + 1;
    reverse(a, a + n);
    ll sum = 0;
    for (int i = 0; i < n; ++i)
        sum = (sum + a[i] * a[i] % mod) % mod;
    S[0] = b[0] * b[0];
    for (int i = 1; i < m; ++i)
        S[i] = (S[i - 1] + b[i] * b[i] % mod) % mod;
    inif(n + m);
    NTT(a, tot, 1), NTT(b, tot, 1);
    for (int i = 0; i < tot; ++i)
        a[i] = a[i] * b[i] % mod;
    NTT(a, tot, -1);
    for (int x = n - 1; x < m; ++x)
    {
        double P = (sum + S[x] - S[x - n] - 2 * a[x]) % mod;
        if (P == 0)
            printf("%d ", x - n + 1);
    }
    return 0;
}

接着我们考虑模糊匹配,即有通配符的情况:https://www.luogu.com.cn/problem/P4173
设通配符的值为0,重新定义字符差

\[c(x,y) = [s1(x) - s2(y)]^2s1(x)s2(y) \]

发现会完美解决问题,依然暴力拆解

\[F(x) = \sum_{i = 0}^{n - 1}[s1(i)-s2(x-n+i+1)]^{2}s1(i)s2(x-n+i+1)\\ =\sum_{i = 0}^{n - 1}[s1(i)^2+s2(x-n+i+1)^2-2s1(i)s2(x-n+i+1)]s1(i)s2(x-n+i+1)\\ =\sum_{i = 0}^{n - 1}s1(i)^3s2(x-n+i+1)+\sum_{i = 0}^{n - 1}s1(i)s2(x-n+i+1)^3-2\sum_{i = 0}^{n - 1}s1(i)^2s2(x-n+i+1)^2\\ =\sum_{i+j=x}^{}s1'(i)^3s2(j)+\sum_{i+j=x}^{}s1'(i)s2(j)^3-2\sum_{i+j=x}^{}s1'(i)^2s2(j)^2 \]

当\(F(x)=0\)时,表明完全匹配

AC代码:

#include <unordered_map>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <string>
#include <stack>
#include <deque>
#include <queue>
#include <cmath>
#include <map>
#include <set>
// Contest-style shortcut: every std name is visible unqualified below.
using namespace std;
// PII, PDI, ull, INF, M, base, P, eps and PI are template boilerplate; none of
// them is referenced by the functions of this program.
typedef pair<int, int> PII;
typedef pair<double, int> PDI;
//typedef __int128 int128;
typedef long long ll;
typedef unsigned long long ull;
const int INF = 0x3f3f3f3f;
// N is the capacity of every global array declared below.
const int N = 1e7 + 10, M = 4e7 + 10;
const int base = 1e9;
const int P = 131;
// Modulus for all modular arithmetic; NTT() uses 3 as the root generator.
const int mod = 998244353;
const double eps = 1e-9;
const double PI = acos(-1.0);
// Lengths read by main for s1 (n) and s2 (m).
int n, m;
// Numeric encodings of s1 (after reversal) and s2; '*' is stored as 0.
int A[N], B[N];
// Raw input strings.
char s1[N], s2[N];
// R: bit-reversal permutation filled by inif(); ans: 1-based match positions.
int R[N], ans[N];
// tot / bit: transform length and its log2 (set by inif()); pos: number of
// entries stored in ans.
int tot, bit, pos;
// a, b: NTT work buffers; p: sum of the three point-wise products.
ll a[N], b[N], p[N];
// Modular power by repeated squaring: returns a^b modulo the global `mod`.
// Expects b >= 0 and a value of `a` small enough that a * a fits in ll.
ll ksm(ll a, ll b)
{
	ll acc = 1 % mod;
	// Each round consumes the lowest bit of b, then squares the base.
	for (; b; b >>= 1, a = a * a % mod)
		if (b & 1)
			acc = acc * a % mod;
	return acc;
}
// Prepares the globals used by NTT() for a product whose highest index is n.
//
// Sets `tot` to the smallest power of two strictly greater than n, `bit` to
// log2(tot), and fills R[0..tot) with the bit-reversal permutation of `bit`-bit
// indices.
void inif(int n)
{
	tot = 1, bit = 0;
	while (tot <= n)
		tot *= 2, ++bit;
	// Index 0 always maps to itself. Starting the recurrence at 1 keeps every
	// write inside [0, tot) and never evaluates `<< (bit - 1)` with bit == 0
	// (which is what happens when n <= 0 and tot stays 1).
	R[0] = 0;
	for (int i = 1; i < tot; ++i)
		R[i] = (R[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}
// In-place number-theoretic transform of f[0..total) modulo `mod`.
//
// f     - coefficient array; entries may be any ll value and are first reduced
//         into [0, mod).
// total - transform length; must be the power of two for which inif() filled
//         the bit-reversal table R, and must divide mod - 1.
// type  - 1 for the forward transform, -1 for the inverse transform (which also
//         divides every entry of f by total).
void NTT(ll f[], int total, int type)
{
	// Canonicalise the input once so that every butterfly below works on
	// residues in [0, mod); negative entries would otherwise propagate and
	// could overflow narrow temporaries.
	for (int i = 0; i < total; ++i)
		f[i] = (f[i] % mod + mod) % mod;
	for (int i = 0; i < total; ++i)
		if (i < R[i])
			swap(f[i], f[R[i]]);
	for (int len = 2; len <= total; len *= 2)
	{
		const int half = len / 2;
		// 332748118 is the inverse of 3 modulo 998244353.
		const ll w1 = ksm(type == 1 ? 3 : 332748118, (mod - 1) / len);
		for (int pos = 0; pos < total; pos += len)
		{
			ll w = 1;
			for (int i = pos; i < pos + half; ++i, w = w * w1 % mod)
			{
				const ll x = f[i];
				const ll y = w * f[i + half] % mod;
				f[i] = (x + y) % mod;
				f[i + half] = (x - y + mod) % mod;
			}
		}
	}
	if (type == -1)
	{
		// Scale the array that was passed in (not the global `a`), over the
		// whole transform length, so the inverse is correct for any buffer.
		const ll inv = ksm(total, mod - 2);
		for (int i = 0; i < total; ++i)
			f[i] = f[i] * inv % mod;
	}
}
int main()
{
	scanf("%d%d%s%s", &n, &m, &s1, &s2);
	reverse(s1, s1 + n);
	for (int i = 0; i < n; ++i)
		A[i] = s1[i] == '*' ? 0 : s1[i] - 'a' + 1;
	for (int i = 0; i < m; ++i)
		B[i] = s2[i] == '*' ? 0 : s2[i] - 'a' + 1;
	inif(n + m);
	//A[i]^3 B[i]
	for (int i = 0; i < tot; ++i)
		a[i] = A[i] * A[i] * A[i];
	for (int i = 0; i < tot; ++i)
		b[i] = B[i];
	NTT(a, tot, 1), NTT(b, tot, 1);
	for (int i = 0; i < tot; ++i)
		p[i] = (p[i] + a[i] * b[i]) % mod;
	//A[i] B[i]^3
	for (int i = 0; i < tot; ++i)
		a[i] = A[i];
	for (int i = 0; i < tot; ++i)
		b[i] = B[i] * B[i] * B[i];
	NTT(a, tot, 1), NTT(b, tot, 1);
	for (int i = 0; i < tot; ++i)
		p[i] = (p[i] + a[i] * b[i]) % mod;
	//A[i]^2 B[i]^2
	for (int i = 0; i < tot; ++i)
		a[i] = A[i] * A[i];
	for (int i = 0; i < tot; ++i)
		b[i] = B[i] * B[i];
	NTT(a, tot, 1), NTT(b, tot, 1);
	for (int i = 0; i < tot; ++i)
		p[i] = (p[i] - 2 * a[i] * b[i] + mod) % mod;

	NTT(p, tot, -1);
	for (int i = n - 1; i < m; ++i)
		if (p[i] == 0)
			ans[++pos] = i - n + 2;

	printf("%d\n", pos);
	for (int i = 1; i <= pos; ++i)
		printf("%d ", ans[i]);
	return 0;
}

然后是杭电多校让我知道了这个知识点
HDU6975:https://acm.hdu.edu.cn/showproblem.php?pid=6975
因为字符只包含0-9和通配符*,首先不考虑通配符,我们可以枚举0-9,将每个子串在0-9情况下的匹配数算出来。以8为例,将所有为8的位置的值设为1,其他位置的值设为0,则对单个字符的匹配数有

\[F(x)=\sum_{i=0}^{n-1}s1(i)s2(x-n+1+i)=\sum_{i=0}^{n-1}s1'(n-i-1)s2(x-n+i+1)=\sum_{i+j=x}s1'(i)s2(j) \]

求出每个子串的匹配数后就可以考虑通配符了。由容斥可得:因通配符而匹配的位置数=\(s1\)中的通配符数+\(s2\)子串中的通配符数-\(s1\)与\(s2\)子串在同一位置同时为通配符的位置数(避免重复计数),前两项用前缀和、最后一项用卷积即可求出

AC代码:

#include <unordered_map>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <string>
#include <stack>
#include <deque>
#include <queue>
#include <cmath>
#include <map>
#include <set>
// Contest-style shortcut: every std name is visible unqualified below.
using namespace std;
// PII, PDI, ull, INF, M, base, P, eps, PI and fp are template boilerplate; none
// of them is referenced by the functions of this program.
typedef pair<int, int> PII;
typedef pair<double, int> PDI;
//typedef __int128 int128;
typedef long long ll;
typedef unsigned long long ull;
const int INF = 0x3f3f3f3f;
// N is the capacity of every global array declared below.
const int N = 1e6 + 10, M = 4e7 + 10;
const int base = 1e9;
const int P = 131;
// Modulus for all modular arithmetic; NTT() uses 3 as the root generator.
const int mod = 998244353;
const double eps = 1e-9;
const double PI = acos(-1.0);
FILE *fp;
// n: length of the pattern s1; m: length of the text s2 (main reads m first).
// tot / bit: transform length and its log2, both set by inif().
int n, m, tot, bit;
// Raw input strings; main reverses s1 in place.
char s1[N], s2[N];
// R: bit-reversal permutation filled by inif(); ans: ans[k] counts windows with
// exactly k mismatches before main turns it into a running total.
int R[N], ans[N];
// a, b: per-character indicator arrays / NTT buffers; f: accumulated match
// counts; S: prefix counts of '*' in s2.
ll a[N], b[N], f[N], S[N];
// Computes a^b modulo the global `mod` with square-and-multiply.
// Expects b >= 0 and a value of `a` small enough that a * a fits in ll.
ll ksm(ll a, ll b)
{
	ll power = 1 % mod;
	while (b != 0)
	{
		const bool lowBitSet = (b & 1) != 0;
		if (lowBitSet)
			power = power * a % mod;
		b >>= 1;
		a = a * a % mod;
	}
	return power;
}
// Resets per-test-case state and prepares the globals used by NTT() for a
// product whose highest index is n.
//
// Clears s1, s2, ans and f, sets `tot` to the smallest power of two strictly
// greater than n, `bit` to log2(tot), and fills R[0..tot) with the bit-reversal
// permutation of `bit`-bit indices.
void inif(int n)
{
	memset(s1, 0, sizeof(s1));
	memset(s2, 0, sizeof(s2));
	memset(ans, 0, sizeof(ans));
	memset(f, 0, sizeof(f));
	tot = 1, bit = 0;
	while (tot <= n)
		tot *= 2, ++bit;
	// Index 0 always maps to itself. Starting the recurrence at 1 keeps every
	// write inside [0, tot) and never evaluates `<< (bit - 1)` with bit == 0
	// (which is what happens when n <= 0 and tot stays 1).
	R[0] = 0;
	for (int i = 1; i < tot; ++i)
		R[i] = (R[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}
// In-place number-theoretic transform of f[0..total) modulo `mod`.
//
// f     - coefficient array (this parameter hides the global of the same
//         name); entries may be any ll value and are first reduced into
//         [0, mod).
// total - transform length; must be the power of two for which inif() filled
//         the bit-reversal table R, and must divide mod - 1.
// type  - 1 for the forward transform, -1 for the inverse transform (which also
//         divides every entry of f by total).
void NTT(ll f[], int total, int type)
{
	// Canonicalise the input once so that every butterfly below works on
	// residues in [0, mod) and never needs narrow int temporaries.
	for (int i = 0; i < total; ++i)
		f[i] = (f[i] % mod + mod) % mod;
	for (int i = 0; i < total; ++i)
		if (i < R[i])
			swap(f[i], f[R[i]]);
	for (int len = 2; len <= total; len *= 2)
	{
		const int half = len / 2;
		// 332748118 is the inverse of 3 modulo 998244353.
		const ll w1 = ksm(type == 1 ? 3 : 332748118, (mod - 1) / len);
		for (int pos = 0; pos < total; pos += len)
		{
			ll w = 1;
			for (int i = pos; i < pos + half; ++i, w = w * w1 % mod)
			{
				const ll x = f[i];
				const ll y = w * f[i + half] % mod;
				f[i] = (x + y) % mod;
				f[i + half] = (x - y + mod) % mod;
			}
		}
	}
	if (type == -1)
	{
		// Scale the whole transformed range rather than only [0, n + m], so the
		// result does not depend on the globals n and m.
		const ll inv = ksm(total, mod - 2);
		for (int i = 0; i < total; ++i)
			f[i] = f[i] * inv % mod;
	}
}
void get(char c, int type)
{
	for (int i = 0; i < tot; ++i)
		a[i] = s1[i] == c;
	for (int i = 0; i < tot; ++i)
		b[i] = s2[i] == c;
	NTT(a, tot, 1), NTT(b, tot, 1);
	for (int i = 0; i < tot; ++i)
	{
		if (type == 1)
			f[i] = (f[i] + a[i] * b[i] % mod) % mod;
		else
			f[i] = (f[i] - a[i] * b[i] % mod + mod) % mod;
	}
}
int main()
{
	int T;
	scanf("%d", &T);
	while (T--)
	{
		scanf("%d%d", &m, &n);
		inif(n + m);
		scanf("%s%s", s2, s1);
		reverse(s1, s1 + n);

		for (char c = '0'; c <= '9'; ++c)
			get(c, 1);
		get('*', -1);
		NTT(f, tot, -1);
		ll sum = 0;
		for (int i = 0; i < n; ++i)
			sum += s1[i] == '*';
		S[0] = s2[0] == '*';
		for (int i = 1; i < m; ++i)
			S[i] = (S[i - 1] + (s2[i] == '*')) % mod;
		for (int i = 0; i < tot; ++i)
		{
			if (i >= n)
				f[i] = (f[i] + sum + S[i] - S[i - n] + mod) % mod;
			else
				f[i] = (f[i] + sum + S[i]) % mod;
		}
		for (int i = n - 1; i < m; ++i)
			++ans[n - f[i]];
		for (int i = 0; i <= n; ++i)
		{
			if (i)
				ans[i] += ans[i - 1];
			printf("%d\n", ans[i]);
		}
	}
	return 0;
}

标签:const,int,s2,s1,FFT,NTT,字符串,include,sum
来源: https://www.cnblogs.com/xiaopangpangdehome/p/15080759.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有