给你一个整数数组 nums
和一个二维数组 queries
,其中 queries[i] = [posi, xi]
。
对于每个查询 i
,首先将 nums[posi]
设置为 xi
,然后计算查询 i
的答案,该答案为 nums
中 不包含相邻元素 的 子序列的 最大 和。
返回所有查询的答案之和。
由于最终答案可能非常大,返回其对 109 + 7
取余 的结果。
今天的题目确实有点难度,一眼看出来应该是个线段树,但线段树的节点该怎么存却卡了很久。只要能把节点想出来就是个板子题,但想不想的出来就不好说了。
本来我是想保存node节点对应区间的最大和的起始和结束位置的,但发现不行,因为这样存的话,node的父节点的答案还是需要一次O(n)遍历才能出结果。
最后我每个node保存了4个元素,cc,co,oc,oo,分别表示该节点对应区间在左右闭区间,左闭右开,左开右闭和左右开区间的情况下的最大和。cc表示左右端点都包括在内,但不一定被选了,而oo表示左右端点都不包括,也就是一定都没被选。其实o和c就是open和close。
那么,根据左右子节点的node数据要怎么计算出该节点的答案呢?
我们以cc的情况为例。父节点是左右闭,那左节点一定是左闭,右节点一定是右闭。同时,因为要求没有连续元素,所以左节点右闭和右节点左闭最多只能有一个存在,所以父节点的cc值应该是max(L.cc+R.oc, L.co+R.oc, L.co+R.cc)。
实际上,因为L.cc是一定大于等于L.co的,所以上面的式子可以不考虑L.co情况,也就是变成
max(L.cc+R.oc, L.co+R.cc)。
另外三个元素同理,具体请看代码。
#define ll long long
const int modx = 1e9+7;
int* a=NULL;
struct node {
ll cc, co, oc, oo;
};
ll max(ll a, ll b)
{
return a>b ? a : b;
}
void build(int pos, int l, int r, struct node* tree)
{
if (l == r)
{
tree[pos] = (struct node){max(a[l], 0), 0, 0, 0};
return;
}
int mid = (l+r)/2;
build(pos*2, l, mid, tree);
build(pos*2+1, mid+1, r, tree);
tree[pos].cc = max(tree[pos*2].cc+tree[pos*2+1].oc, tree[pos*2].co+tree[pos*2+1].cc);
tree[pos].co = max(tree[pos*2].co+tree[pos*2+1].co, tree[pos*2].cc+tree[pos*2+1].oo);
tree[pos].oo = max(tree[pos*2].oc+tree[pos*2+1].oo, tree[pos*2].oo+tree[pos*2+1].co);
tree[pos].oc = max(tree[pos*2].oc+tree[pos*2+1].oc, tree[pos*2].oo+tree[pos*2+1].cc);
}
void update(int pos, int value, int tar, int l, int r, struct node* tree)
{
if (l == r)
{
tree[pos] = (struct node){max(value, 0), 0, 0, 0};
return;
}
int mid = (l+r)/2;
if (tar <= mid) update(pos*2, value, tar, l, mid, tree);
else update(pos*2+1, value, tar, mid+1, r, tree);
tree[pos].cc = max(tree[pos*2].cc+tree[pos*2+1].oc, tree[pos*2].co+tree[pos*2+1].cc);
tree[pos].co = max(tree[pos*2].co+tree[pos*2+1].co, tree[pos*2].cc+tree[pos*2+1].oo);
tree[pos].oo = max(tree[pos*2].oc+tree[pos*2+1].oo, tree[pos*2].oo+tree[pos*2+1].co);
tree[pos].oc = max(tree[pos*2].oc+tree[pos*2+1].oc, tree[pos*2].oo+tree[pos*2+1].cc);
}
int maximumSumSubsequence(int* nums, int numsSize, int** queries, int queriesSize, int* queriesColSize) {
ll ans=0;
a = nums-1;
struct node tree[numsSize*4 + 10];
build(1, 1, numsSize, tree);
for (int i=0; i<queriesSize; ++i)
{
int pos = queries[i][0]+1, value = queries[i][1];
update(1, value, pos, 1, numsSize, tree);
ans += tree[1].cc;
}
return ans%modx;
}
更新节点的时间复杂度是O(logn),所以程序的时间复杂度是O(n+q*logn)。q是q次查询
空间复杂度是O(n)。