D. “a” String Problem
题意
给定一个字符串 s s s,要求把 s s s 拆分成若干段,满足以下要求:
- 拆分出来的每一个子段,要么是子串 t t t,要么是字符 a a a
- 子串 t t t 至少出现一次
- t ≠ " a " t \neq "a " t="a"
问有多少种不同的子串 t t t 满足以上要求
思路
如果 s s s 全是 a a a 的话,假设 ∣ s ∣ = n |s| = n ∣s∣=n,那么答案是: n − 1 n - 1 n−1
否则通过简单观察我们可以发现: t t t 必须包含非 a a a 字符(不一定是所有,只要 t t t 可以覆盖这些非 a a a 字符即可)
假设
s
s
s 从左到右第一个出现非
a
a
a 字符的位置是
p
0
p_0
p0,如果我们先固定
t
t
t 的开头在
p
0
p_0
p0,
我们就可以先枚举
t
t
t 的长度从
1
→
n
−
p
0
+
1
1 \rarr n - p_0 + 1
1→n−p0+1,那么如何确定当前
t
t
t 是否满足上述要求?
我们直接在第一个 t t t 的末尾后面,找第一个出现的非 a a a 字符作为第二个 t t t 的开头,然后这后面 l e n len len 个字符必须与 t t t 相等,如果相等则继续往后检查,否则当前 t t t 无效。
我们只需要预处理每一个位置后面第一个非
a
a
a 字符的位置就可以,倒着扫一遍就可以线性预处理出来。
匹配的过程我们可以使用
Z
Z \;
Z函数,这是由于本质上是从
p
0
p_0
p0 开始的字符串,拿它去和它自己本身的每个后缀做匹配,自然可以使用
Z
Z \;
Z 函数。
那么这个检查过程是:
O
(
n
log
n
)
O(n \log n)
O(nlogn) 的(调和级数复杂度)
那么现在问题在于:
t
t
t 不一定是以非
a
a
a 字符开头的。
其实这个问题很容易处理,假设我们当前有效的从
p
0
p_0
p0 开头的
∣
t
∣
=
l
e
n
|t| = len
∣t∣=len,那么我们在检查过程的同时,记录每个
t
t
t 的前面到前一个
t
t
t 的末尾,有多少个
a
a
a,统计这个最小值
m
n
mn
mn
那么很显然,当前的
t
t
t 可以往前扩展最多
m
n
mn
mn 个
a
a
a,最后还是有效的。
那么这里的方案数就是:
m
n
+
1
mn + 1
mn+1
所以答案最后对于每个有效长度累加即可
时间复杂度: O ( n log n ) O(n \log n) O(nlogn)
#include<bits/stdc++.h>
#define fore(i,l,r) for(int i=(int)(l);i<(int)(r);++i)
#define fi first
#define se second
#define endl '\n'
#define ull unsigned long long
#define ALL(v) v.begin(), v.end()
const int INF=0x3f3f3f3f;
const long long INFLL=0x3f3f3f3f3f3f3f3fLL;
typedef long long ll;
std::vector<int> z_function(const std::string& s, int n){
std::vector<int> z(n + 1, 0);
z[1] = n;
int l = 0, r = 0;
fore(i, 2, n + 1){
if(i <= r) z[i] = std::min(z[i - l + 1], r - i + 1);
while(i + z[i] <= n && s[1 + z[i]] == s[i + z[i]])
++z[i];
if(i + z[i] - 1 > r){
l = i;
r = i + z[i] - 1;
}
}
return z;
}
int main(){
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
int t;
std::cin >> t;
while(t--){
std::string s;
std::cin >> s;
int n = s.size();
s = '0' + s;
std::vector<int> nxt_nona(n + 5, n + 5);
for(int i = n; i > 0; --i){
if(s[i] == 'a') nxt_nona[i] = nxt_nona[i + 1];
else nxt_nona[i] = i;
}
if(nxt_nona[1] > n){ //全是a
std::cout << n - 1 << endl;
continue;
}
int p0 = nxt_nona[1];
std::string T = s.substr(p0);
T = '0' + T;
auto z = z_function(T, T.size() - 1);
ll ans = 0;
fore(len, 1, n - p0 + 2){
int mn = p0 - 1;
bool ok = true;
int lst = p0 + len - 1;
for(int j = nxt_nona[p0 + len]; j <= n; j = nxt_nona[j + len]){
if(z[j - p0 + 1] < len){
ok = false;
break;
}
mn = std::min(mn, j - lst - 1);
lst = j + len - 1;
}
if(ok) ans += mn + 1;
}
std::cout << ans << endl;
}
return 0;
}