D - 统计子矩阵 (双指针+前缀和+降维处理)
1、问题
D - 统计子矩阵
2、分析 + 代码
(1)纯暴力做法:
这个做法就很简单了,我们直接枚举所有的子矩阵,然后在对每一个子矩阵内部的元素逐一累加起来计算比较即可。这里就不写代码了,直接从第二种方法开始。
(2)暴力+二维前缀和:
这里和方法一是一致的,依旧是去枚举所有的子区间,在计算子区间的和的时候,我们使用二维前缀和的算法进行优化,从而将计算矩阵和的算法的复杂度从
O
(
n
)
O(n)
O(n)降低到
O
(
1
)
O(1)
O(1)。
此时我们能过百分之七八十的数据,但无法过所有的数据,因为此时的时间复杂度是
O
(
n
4
)
O(n^4)
O(n4)的,依旧很高。
#include<bits/stdc++.h>
#define endl '\n'
#define INF 0x3f3f3f3f
#define int long long
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
const int N = 5e2 + 10;
int a[N][N], s[N][N];
int n, m, k;
int ans;
int cal(int x1, int y1, int x2, int y2)
{
int res = s[x2][y2] - s[x1 - 1][y2] - s[x2][y1 - 1] + s[x1 - 1][y1 - 1];
return res;
}
void solve()
{
cin >> n >> m >> k;
for(int i = 1; i <= n; i ++ )
{
for(int j = 1; j <= m; j ++ )
{
cin >> a[i][j];
s[i][j] = s[i - 1][j] + s[i][j - 1] - s[i - 1][j - 1] + a[i][j];
}
}
for(int len_r = 1; len_r <= n; len_r ++)
{
for(int len_c = 1; len_c <= m; len_c ++)
{
for(int i = 1; i + len_r - 1 <= n; i ++ )
{
int x1 = i, x2 = i + len_r - 1;
for(int j = 1; j + len_c - 1 <= m; j ++)
{
int y1 = j, y2 = j + len_c - 1;
//cout << x1 << " " << y1 << " " << x2 << " " << y2 << endl;
if(cal(x1, y1, x2, y2) <= k)
ans ++;
}
}
}
}
cout << ans << endl;
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
solve();
}
当然,我们还可以进行一些优化,比如我们去标记一下那些大于 k k k的矩阵,这样的话,当后续的较大矩阵包含标记矩阵的时候,肯定也是不符合条件的,这样就可以不用去枚举了。如果数据不强的话,这个方法可以勉强过。这里就不做展示了。
(3)降维打击+双指针:
现在开始介绍这道题的正解:我们可以将二维转化成1维,转化方式如下:
对于每一个矩阵而言,我们将列都加在一起,这样就能看成一个一维的区间。
此时我们就可以准备两条线横线,两条横线之间的部分就是我们需要讨论的一维区间。如下图所示:
去枚举这两条线的时间复杂度是
O
(
N
2
)
O(N^2)
O(N2)的。
那么对于转换后的一维区间我们该如何讨论呢?此时我们使用的算法是:双指针。
我们下面直接针对转化后的一维数组进行讨论。
我们定义两个指针:
l
l
l和
r
r
r,我们先固定
l
l
l指针,然后让
r
r
r指针向右移动,一遍移动一遍记录扫过的元素的和,当我们的和大于
k
k
k的时候,我们的
r
r
r指针停止扫描。此时,我们以
l
l
l为左端点的合法区间就是
(
r
−
l
+
1
)
(r-l+1)
(r−l+1)。此时,我们就让
l
l
l向右移动,直到区间内的元素和小于等于
k
k
k为止。然后再让
r
r
r继续移动,重复上述操作。
该过程由于只扫描了一遍,所以时间复杂度是
O
(
n
)
O(n)
O(n)的。
那么总的时间复杂度就是
O
(
n
3
)
O(n^3)
O(n3)的。
#include<bits/stdc++.h>
#define endl '\n'
#define INF 0x3f3f3f3f
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
const int N = 500 + 10;
ll a[N][N];
ll s[N][N];
ll ans;
int n, m ,k;
void solve()
{
cin >> n >> m >> k;
for(int i = 1; i <= n; i ++ )
for(int j = 1; j <= m; j ++ )
cin >> a[i][j];
for(int j = 1; j <= m; j ++ )
for(int i = 1; i <= n; i ++)
s[i][j]=s[i-1][j]+a[i][j];
// for(int i = 1; i <= m; i ++ )
// {
// cout << s[i] << " ";
// }
// cout << endl;
for(int l1 = 1; l1 <= n; l1 ++ )
{
for(int l2 = l1; l2 <= n; l2 ++ )
{
vector<int>col(m + 1);
for(int i = 1; i <= m; i ++ )
col[i] = s[l2][i] - s[l1 - 1][i];
ll sum = 0, l = 1;
for(int r = 1; r <= m; r ++ )
{
sum += col[r];
if(sum <= k)
ans += (r - l + 1);
else
{
while(sum > k)
{
sum -= col[l];
l ++;
}
ans += r - l + 1;
}
}
}
}
cout << ans << endl;
}
int main()
{
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
solve();
}