理论基础
有递归就有回溯。回溯搜索是一种纯暴力搜索算法。我们一层一层递归到最底层收获结果,比如下面我们最后一层+1操作之后,我们只有撤销这个操作回退到上一个节点才能遍历该层的其他节点,这个回退撤销操作就是回溯。
回溯法,一般可以解决如下几种问题:
组合问题:N个数里面按一定规则找出k个数的集合
切割问题:一个字符串按一定规则有几种切割方式
子集问题:一个N个数的集合里有多少符合条件的子集
排列问题:N个数按一定规则全排列,有几种排列方式
棋盘问题:N皇后,解数独等等
Note:组合是不强调元素顺序的,排列是强调元素顺序。
回溯法解决的问题都可以抽象为树形结构。从上一层到下一层就是一次递归操作。每一层的节点数量就是该层递归操作了多少次或者循环了多少次用来确定for的参数。确定递归函数的参数和返回值,思考我们在确定终止条件、for循环和单层递归操作的时候需要使用哪些参数。递归终止条件根据题目要求,在图中表示为树的深度。单层递归逻辑,需要怎么处理节点,怎么得到或者不断完善结果。
回溯三部曲:
- 确定递归函数的参数和返回值
- 确定递归终止条件
- 单层递归逻辑
回溯算法的模板如下:
void backtracking(参数){//参数需要什么就添加什么
if(终止条件){//递归终止条件,也就是题目要求
存放结果;
return;
}
for(选择:本层集合中的元素(树中节点孩子的数量就是集合的大小)){
处理节点;//单层递归处理逻辑在这里
backtracking(路径, 选择列表);//递归
回溯,撤销处理结果//回溯,撤销我们处理节点的操作
}
}
参考文章
- https://www.programmercarl.com/%E5%9B%9E%E6%BA%AF%E7%AE%97%E6%B3%95%E7%90%86%E8%AE%BA%E5%9F%BA%E7%A1%80.html#%E7%90%86%E8%AE%BA%E5%9F%BA%E7%A1%80
77. 组合
回溯三部曲:
- 确定递归函数的参数和返回值:n, k, startIndex(因为组合的定义是无序的,而且元素不允许重复,所以定义一个startIndex来指明之前收获到哪个元素,避免得到重复的组合)。无返回值(结果都定义为全局变量了,path保存当前的结果一维数组,result保存符合条件的结果的集合二维数组)
- 确定递归终止条件:path.size()==k说明得到符合条件的结果了就保存结果result.push(path);并停止向下递归。(需要的组合的长度决定递归的深度)
- 单层递归逻辑:path.push(i);把当前的元素放到结果列表中。
剪枝操作:避免没有意义的递归操作,在for (int i = startIndex; i <= n; i++),把n改成n-(k-path.size())+1,因为我们要取n个值,但是如果当前剩下的可供我们选择的数不足就不需要继续进行递归操作了。注意i是可以取到这个值,说明这个是可以提取到n个值的所以有一个+1的操作。
下面是C++, JAVA, Python的代码。
class Solution {
// // 下面是进行剪枝优化的
private:
vector<vector<int>> result;//
vector<int> path;
void backtracking(int n, int k, int startIndex){
if(path.size() == k){
result.push_back(path);
return;
}//这个是终止条件
for(int i = startIndex; i <= n- (k - path.size())+1; i++){//优化的地方
path.push_back(i);
backtracking(n, k , i+1);
path.pop_back();//回溯
}
}
public:
vector<vector<int>> combine(int n, int k){
backtracking(n, k, 1);
return result;
}
// //下面是没有进行优化的
// private:
// vector<vector<int>> result;//存放符合条件结果的集合
// vector<int> path;//用来存放符合条件的结果
// void backtracking(int n, int k, int startIndex){//回溯算法的参数,需要什么定义什么,n是书的个数,k是结果的长度,startIndex是当前循环的起始位置。为了防止重复把之前遍历过的筛选掉
// if (path.size() == k){
// result.pash_back(path);//存放结果
// return;
// }//递归终止条件,只要找到符合长度的结果就不要继续向下递归了
// for (int i = startIndex; i <= n; i++){
// path.push_back(i);
// backtracking(n, k, i+1);//递归
// path.pop_back();//回溯,撤销处理的节点
// }
// }
// public:
// vector<vector<int>> combine(int n, int k) {
// result.clear();
// path.clear();
// backtracking(n, k, 1);
// return result;
// }
};
class Solution {
// // 下面是进行剪枝优化的
List<List<Integer>> result = new ArrayList<>();
LinkedList<Integer> path = new LinkedList<>();
public List<List<Integer>> combine(int n, int k){
backtracking(n, k, 1);
return result;
}
public void backtracking(int n, int k, int startIndex){
if(path.size() == k){
result.add(new ArrayList<>(path));
return;
}
for(int i = startIndex; i <= n- (k - path.size()) + 1; i++){
path.add(i);
backtracking(n, k, i+1);
path.removeLast();
}
}
// // 下面是没有进行剪枝优化的
// List<List<Integer>> result = new ArrayList<>();//存放符合条件的结果列表
// List<Integer> path = new LinkedList<>();//存放当前的结果
// public List<List<Integer>> combine(int n, int k) {
// backtracking(n, k, 1);
// return result;
// }
// public void backtracking(int n, int k, int startIndex){
// if(path.size() == k){
// result.add(new ArrayList<>(path));//如果都赋值为path那么这个result中所有元素都指向一个位置,随着path的变化result也会变化
// return;
// }
// for(int i = startIndex; i<=n; i++){
// path.add(i);
// backtracking(n, k, i+1);
// path.removeLast();//回溯
// }
// }
}
class Solution(object):
def combine(self, n, k):
result = []#存放结果集
self.backtracking(n, k, 1, [], result)
return result
def backtracking(self, n, k, startIndex, path, result):
if len(path) == k:
result.append(path[:])
return
for i in range(startIndex, n- (k -len(path))+2):#因为右区间取不到所以大一个
path.append(i)#处理节点
self.backtracking(n, k, i+1, path, result)
path.pop()
# def combine(self, n, k):
# """
# :type n: int
# :type k: int
# :rtype: List[List[int]]
# """
# result = [] #存放结果
# self.backtracking(n, k, 1, [], result)
# return result
# def backtracking(self, n, k, startIndex, path, result):
# if len(path) == k:
# result.append(path[:])
# return
# for i in range(startIndex, n+1):#需要优化的地方
# path.append(i) #处理节点
# self.backtracking(n, k, i+1, path, result)
# path.pop()
参考文章
- https://www.programmercarl.com/0077.%E7%BB%84%E5%90%88.html#%E7%AE%97%E6%B3%95%E5%85%AC%E5%BC%80%E8%AF%BE
216.组合总和III
和组合几乎一样,就是判断条件不同。需要定义一个sum保存当前的和值,来判断是否满足目标和值。
回溯三部曲:
- 确定递归函数的参数和返回值:n, k, startIndex(因为组合的定义是无序的,而且元素不允许重复,所以定义一个startIndex来指明之前收获到哪个元素,避免得到重复的组合)。无返回值(结果都定义为全局变量了,path保存当前的结果一维数组,result保存符合条件的结果的集合二维数组)
- 确定递归终止条件:path.size()==k说明得到组合的长度符合条件,如果 t a r g e t S u m = = s u m targetSum==sum targetSum==sum将结果保存下来,result.push_back(path);
- 单层递归逻辑:path.push(i);把当前的元素放到结果列表中。sum+=i;//把当前的结果列表更新。
剪枝:从和值和集合个数两个维度进行剪枝。一个是在函数最开始加入if(sum>targetSum) return; #如果当前的和值已经大于目标和值了就直接返回。i<=9这里可以修改成9-(k-path.size())+1,也就是如果当前可以选择的数如果不足以得到符合条件的结果就直接返回。
class Solution {
private:
vector<vector<int>> result;//存放结果集
vector<int> path;//符合条件的结果
void backtracking(int targetSum, int k, int sum, int startIndex){
if(sum > targetSum){//剪枝操作
return;
}
if(path.size()==k){
if(sum == targetSum) result.push_back(path);
return;
}
for(int i = startIndex; i <= 9 - ( k - path.size())+1; i++){//剪枝操作
sum += i;//处理
path.push_back(i);//处理
backtracking(targetSum, k, sum, i + 1);
sum -= i;//回溯
path.pop_back();//回溯
}
}
public:
vector<vector<int>> combinationSum3(int k, int n){
result.clear();
path.clear();
backtracking(n, k, 0, 1);
return result;
}
// private:
// vector<vector<int>> result;//
// vector<int> path;
// void backtracking(int n, int k, int startIndex){
// if(path.size() == k){
// int sum = std::accumulate(path.begin(), path.end(), 0);
// if(sum == n) result.push_back(path);
// return;
// }//这个是终止条件
// for(int i = startIndex; i <= 9- (k - path.size())+1; i++){//优化的地方
// path.push_back(i);
// backtracking(n, k , i+1);
// path.pop_back();//回溯
// }
// }
// public:
// vector<vector<int>> combinationSum3(int k, int n) {
// backtracking(n, k, 1);
// return result;
// }
};
class Solution {
List<List<Integer>> result = new ArrayList<>();
LinkedList<Integer> path = new LinkedList<>();
public List<List<Integer>> combinationSum3(int k, int n) {
backTracking(n, k, 1, 0);
return result;
}
private void backTracking(int targetSum, int k, int startIndex, int sum){
//剪枝
if(sum > targetSum){
return;
}
if(path.size() == k){
if(sum == targetSum) result.add(new ArrayList<>(path));
return;
}
//剪枝 9 - (k - path.size()) +1
for(int i = startIndex; i <= 9 - (k - path.size())+1; i++){
path.add(i);
sum += i;
backTracking(targetSum, k, i+1, sum);
sum -= i;//回溯
path.removeLast();//回溯
}
}
}
class Solution(object):
def combinationSum3(self, k, n):
"""
:type k: int
:type n: int
:rtype: List[List[int]]
"""
result = []
self.backtracking(n, k, 0, 1, [], result)
return result
def backtracking(self, targetSum, k, currentSum, startIndex, path, result):
if currentSum > targetSum: #剪枝操作
return#如果当前的和值已经大于目标值了就不需要向下递归了
if len(path) == k:
if currentSum == targetSum:
result.append(path[:])
return
for i in range(startIndex, 9 - (k - len(path))+2):#剪枝
currentSum += i
path.append(i)
self.backtracking(targetSum, k, currentSum, i+1, path, result)
currentSum -= i
path.pop()
参考文章
- https://www.programmercarl.com/0216.%E7%BB%84%E5%90%88%E6%80%BB%E5%92%8CIII.html
17.电话号码的字母组合
回溯三部曲:
- 确定递归函数的参数和返回值:参数是当前输入的数字digits和当前处理的数字的索引index。没有返回值,因为我们定义的全局变量保存单个结果和结果列表。string s;收获单个结果vector result;把每个符合条件的结果保存在结果集。与前两个不同的是没有startIndex,因为之前的是在一个集合中收集元素所以需要startIndex指明之前收获到哪个元素了,避免得到重复的组合。但是这个是在不同的字符集合中进行收集,不需要startIndex去控制集合中我们之前遍历过哪些元素。
- 确定递归终止条件:index==digits.size();说明已经处理完所有输入的数字了,保存结果result.push_back(s)
- 单层递归逻辑:digit=digits[index]-'0’获得当前处理的字符的int类型数据。string letter = letterMap[digit]获得当前处理的数字对应的字符。遍历字符并将当前遍历的字符letter[i] push到s中。
无剪枝操作。
class Solution {
private:
const string letterMap[10] = {
"",//0
"",//1
"abc",//2
"def",//3
"ghi",//4
"jkl",//5
"mno",//6
"pqrs",//7
"tuv",//8
"wxyz",//9
};
string s;//收获单个结果
vector<string> result;//把每个符合条件的结果放在结果集
void backtracking(const string& digits, int index){
if(index==digits.size()){//遍历到哪个数字
result.push_back(s);
return;
}
int digit = digits[index]-'0';
string letter = letterMap[digit];
for(int i = 0; i< letter.size(); i++){
s.push_back(letter[i]);
backtracking(digits, index+1);
s.pop_back();//回溯
}
}
public:
vector<string> letterCombinations(string digits) {
s.clear();
result.clear();
if(digits.size() == 0){
return result;
}
backtracking(digits, 0);
return result;
}
// // 下面隐藏了回溯过程
// private:
// const string letterMap[10] = {
// "",//0
// "",//1
// "abc",//2
// "def",//3
// "ghi",//4
// "jkl",//5
// "mno",//6
// "pqrs",//7
// "tuv",//8
// "wxyz",//9
// };
// public:
// vector<string> result;
// string s;
// void getCombinations(const string& digits, int index, const string& s){
// if(index==digits.size()){
// result.push_back(s);
// return;
// }
// int digit = digits[index] - '0';//将index的字母转换为int
// string letters = letterMap[digit];//取数字对应的字符集
// for (int i = 0; i<letters.size(); i++){
// getCombinations(digits, index+1, s+letters[i]);
// }
// }
// vector<string> letterCombinations(string digits){
// s.clear();
// result.clear();
// if(digits.size()==0){
// return result;
// }
// getCombinations(digits, 0, "");
// return result;
// }
};
class Solution {
//设置全局列表存储最后结果
List<String> list = new ArrayList<>();
public List<String> letterCombinations(String digits) {
if(digits==null || digits.length() == 0){
return list;
}
//初始对应所有的数字,为了直接对应2-9,新增了两个无效的字符串""
String[] numString = {"", "", "abc", "def", "ghi", "jkl", "mno", "pqrs", "tuv", "wxyz"};
//迭代处理
backtracking(digits, numString, 0);
return list;
}
//每次迭代获取一个字符串,所以会涉及大量的字符串拼接,这里选择更加高效的StringBuilder
StringBuilder temp = new StringBuilder();//存储当前产生的字符串的吗
public void backtracking(String digits, String[] numString, int num){
//遍历全部一次记录一次得到的字符串
if(num == digits.length()){//这个就是递归的终止条件
list.add(temp.toString());
return;
}
//str 表示当前num对应的字符串
String str = numString[digits.charAt(num)-'0'];
for(int i = 0; i < str.length(); i++){
temp.append(str.charAt(i));
//递归,处理下一层
backtracking(digits, numString, num + 1);
//剔除末尾的继续尝试
temp.deleteCharAt(temp.length() -1);
}
}
}
class Solution(object):
def __init__(self):
self.letterMap = [
"", # 0
"", # 1
"abc", # 2
"def", # 3
"ghi", # 4
"jkl", # 5
"mno", # 6
"pqrs", # 7
"tuv", # 8
"wxyz" # 9
]
self.result = []
self.s = ""
def backtracking(self, digits, index):
if index == len(digits):
self.result.append(self.s)
return
digit = int(digits[index])#将索引处的数字转换为整数
letters = self.letterMap[digit]#获取对应的字符集
for i in range(len(letters)):
self.s += letters[i]
self.backtracking(digits, index+1)#递归调用,注意索引加1
self.s = self.s[:-1] #回溯,删除最后添加的字符
def letterCombinations(self, digits):
"""
:type digits: str
:rtype: List[str]
"""
if len(digits) == 0:
return self.result
self.backtracking(digits, 0)
return self.result
参考文章
- https://www.programmercarl.com/0017.%E7%94%B5%E8%AF%9D%E5%8F%B7%E7%A0%81%E7%9A%84%E5%AD%97%E6%AF%8D%E7%BB%84%E5%90%88.html