文章目录
- 分值:200
- 题目描述
- 思路
- 复杂度分析
- AC 代码
分值:200
题目描述
给定 M 个字符(a
-z
),从中取出任意字符(每个字符只能用一次)拼接成长度为 N
的字符串,要求相同的字符不能相邻。
计算出给定的字符列表能拼接出多少种满足条件的字符串,无法拼接出满足条件的字符串则返回 0。
输入描述:
给定长度为 M
的字符列表和结果字符串的长度 N
,中间使用空格
拼接。
输出描述:
输出满足条件的字符串个数。
示例1
输入:
abc 1
输出:
3
解释:
给定的字符为 abc ,结果字符申长度为 1 ,可以拼接成 a、b、c ,共 3 种。
示例1
输入:
qwertyuiopasdfghjklzxcvbnm 5
输出:
7893600
解释:
此用例为比较极限的用例,写完代码后跑一下这个用例看一下大概用时多少,确保在 1s 内跑完。
Tips:
0 < M < 30
0 < N ≤ 5
思路
- 本题考查的是DFS回溯去重。
- 思考1:如果输入的字符串不包含重复字符,该怎么做?全排列
- 思考2:如果输入的字符串包含重复字符,该怎么做?全排列+去重
- 思考3:如何优雅地去重?
首先将字符串 s 转成字符数组 arr,对字符数组排序,排序之后,相同字符位于字符数组中的相邻位置,可以利用这一特点去重。
对于去重操作,需要考虑产生重复排列的原因。如果一个字符在字符数组中出现k
次,则对于任意一个排列,将这k
个字符的相对顺序交换之后会得到与原排列重复的排列,因此,为了避免重复排列,需要确保这k
个字符加入排列的顺序是固定的。具体而言,当i > 0
时,如果arr[i] = arr[i − 1]
(此时字符数组arr
已排序),则需要确保arr[i − 1]
在arr[i]
之前加入排列。
根据上述分析,在排序后的字符数组arr
中遍历到下标i
时,以下两种情况不应将arr[i]
加入当前排列。
1.如果arr[i]
已经加入当前排列,则不能多次加入当前排列。
2.如果当i > 0
时,arr[i] = arr[i − 1]
且arr[i − 1]
未加入当前排列,则不能将arr[i]
加入当前排列,否则arr[i − 1]
将在arr[i]
之后加入当前排列,导致出现重复排列。
复杂度分析
- 时间复杂度:
O
(
N
M
)
O(N^M)
O(NM),其中
N
为输入的字符串的长度,M
为目标字符串的长度。 - 空间复杂度:
O
(
N
)
O(N)
O(N),其中
N
为输入的字符串的长度。
AC 代码
C++ 版
#include <bits/stdc++.h>
using namespace std;
int ans = 0, n, vis[31];
void dfs(string &cur, string &s)
{
// 递归边界
if (cur.size() == n)
{
ans++;
return;
}
for (int i = 0; i < s.size(); i++)
{
// 防止生成重复字符串
if (vis[i] || (i > 0 && s[i] == s[i - 1] && !vis[i - 1]) || (cur.size() > 0 && cur.back() == s[i]))
{
continue;
}
vis[i] = true;
cur += s[i];
dfs(cur, s);
cur.pop_back();
vis[i] = false;
}
}
int main()
{
string str, cur = "";
memset(vis, 0, sizeof(vis));
cin >> str >> n;
sort(str.begin(), str.end());
dfs(cur, str);
cout << ans << endl;
return 0;
}
JAVA 版
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Scanner;
public class Main {
static int ans = 0;
static int n;
static int[] vis;
public static void dfs(List<Character> cur, String s) {
if (cur.size() == n) {
ans++;
return;
}
for (int i = 0; i < s.length(); i++) {
if (vis[i] == 1 || (i > 0 && s.charAt(i) == s.charAt(i - 1) && vis[i - 1] == 0) || (cur.size() > 0 && cur.get(cur.size() - 1) == s.charAt(i))) {
continue;
}
vis[i] = 1;
cur.add(s.charAt(i));
dfs(cur, s);
cur.remove(cur.size() - 1);
vis[i] = 0;
}
}
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
String str = scanner.next();
n = scanner.nextInt();
vis = new int[31];
Arrays.fill(vis, 0);
char[] charArray = str.toCharArray();
Arrays.sort(charArray);
str = new String(charArray);
List<Character> cur = new ArrayList<>();
dfs(cur, str);
System.out.println(ans);
}
}
Python 版
ans = 0
n = 0
vis = [0] * 31
cur = []
s = ''
def dfs():
global ans
if len(cur) == n:
ans += 1
return
for i in range(len(s)):
if vis[i] or (i > 0 and s[i] == s[i - 1] and not vis[i - 1]) or (len(cur) > 0 and cur[-1] == s[i]):
continue
vis[i] = 1
cur.append(s[i])
dfs()
cur.pop()
vis[i] = 0
if __name__ == "__main__":
str_input = input().split()
str = str_input[0]
n = int(str_input[1])
vis = [0] * 31
s = s.join(sorted(str))
dfs()
print(ans)