一、源码及框架分析
SGI-STL30版本源代码中,map和set的源代码在map/set/stl_map.h/stl_set.h/stl_tree.h等几个头文件中。
map和set的实现结构框架核心部分截取出来如下:
// set
#ifndef __SGI_STL_INTERNAL_TREE_H
#include <stl_tree.h>
#endif
#include <stl_set.h>
#include <stl_multiset.h>
// map
#ifndef __SGI_STL_INTERNAL_TREE_H
#include <stl_tree.h>
#endif
#include <stl_map.h>
#include <stl_multimap.h>
// stl_set.h
// SGI set (excerpt). The element stored in the tree is the key itself, so
// key_type and value_type are both Key.
template <class Key, class Compare = less<Key>, class Alloc = alloc>
class set {
public:
// typedefs:
typedef Key key_type;
typedef Key value_type;
private:
// Underlying tree: rb_tree<Key, Value, KeyOfValue, Compare, Alloc> with
// Value = Key and identity<value_type> as the key-extraction functor.
// NOTE(review): the key_compare typedef is not part of this excerpt
// (presumably a typedef for Compare in the full stl_set.h).
typedef rb_tree<key_type, value_type, identity<value_type>, key_compare, Alloc> rep_type;
rep_type t; // red-black tree representing set
};
// stl_map.h
// SGI map (excerpt). The element stored in the tree is pair<const Key, T>.
// Here T is the mapped type (the "value" of key/value), not the tree's Value.
template <class Key, class T, class Compare = less<Key>, class Alloc = alloc>
class map {
public:
// typedefs:
typedef Key key_type;
typedef T mapped_type;
typedef pair<const Key, T> value_type;
private:
// Underlying tree: Value = pair<const Key, T>, with select1st<value_type> as
// the key-extraction functor (presumably yielding the pair's .first).
// NOTE(review): the key_compare typedef is not part of this excerpt.
typedef rb_tree<key_type, value_type, select1st<value_type>, key_compare, Alloc> rep_type;
rep_type t; // red-black tree representing map
};
// stl_tree.h
// Part shared by every tree node: the colour and the three links
// (parent/left/right). The stored element is added by the derived
// __rb_tree_node<Value>.
struct __rb_tree_node_base
{
typedef __rb_tree_color_type color_type;
typedef __rb_tree_node_base* base_ptr;
color_type color;
base_ptr parent;
base_ptr left;
base_ptr right;
};
// stl_tree.h
// SGI rb_tree (excerpt). Template parameters:
//   Key        - type taken by lookup/removal (find, erase)
//   Value      - type actually stored in each node (what insert takes)
//   KeyOfValue - functor obtaining the Key from a stored Value
//                (identity<> for set, select1st<> for map)
// NOTE(review): iterator and size_type are declared elsewhere in stl_tree.h;
// they are not part of this excerpt.
template <class Key, class Value, class KeyOfValue, class Compare, class Alloc = alloc>
class rb_tree {
protected:
typedef void* void_pointer;
typedef __rb_tree_node_base* base_ptr;
typedef __rb_tree_node<Value> rb_tree_node;
typedef rb_tree_node* link_type;
typedef Key key_type;
typedef Value value_type;
public:
// insert takes the second template parameter (Value) as its parameter type
pair<iterator,bool> insert_unique(const value_type& x);
// erase and find take the first template parameter (Key) as their parameter type
size_type erase(const key_type& x);
iterator find(const key_type& x);
protected:
size_type node_count; // keeps track of size of tree
// sentinel header node; it also serves as end()
link_type header;
};
// Complete node type: the base links plus the stored element. Value is the
// rb_tree's second template argument (Key for set, pair<const Key, T> for map).
template <class Value>
struct __rb_tree_node : public __rb_tree_node_base
{
typedef __rb_tree_node<Value>* link_type;
Value value_field;
};
• 通过对上面框架的分析,我们可以看到源码中rb_tree用了一个巧妙的泛型思想实现:rb_tree是实现key的搜索场景,还是key/value的搜索场景,不是直接写死的,而是由第二个模板参数Value决定__rb_tree_node中存储的数据类型。
• set实例化rb_tree时第二个模板参数给的是key,map实例化rb_tree时第二个模板参数给的是pair<const key, T>,这样一棵红黑树既可以实现key搜索场景的set,也可以实现key/value搜索场景的map。
• 要注意一下,源码里面模板参数是用T代表value,而内部写的value_type不是我们日常key/value场景中说的value,源码中的value_type反而是红黑树结点中存储的真实的数据的类型。
• rb_tree第二个模板参数Value已经控制了红黑树结点中存储的数据类型,为什么还要传第一个模板参数Key呢?尤其是set,两个模板参数是一样的,这是很多同学这时的一个疑问。要注意的是对于map和set,find/erase时的函数参数都是Key,所以第一个模板参数是传给find/erase等函数做形参的类型的。对于set而言两个参数是一样的,但是对于map而言就完全不一样了,map insert的是pair对象,但是find和erase的是Key对象。
二、模拟实现 set 和 map
2.1 实现出符合要求的红黑树
2.1.1 iterator 与 reverse_iterator
iterator实现思路分析
• iterator实现的大框架跟list的iterator思路是一致的,用一个类型封装结点的指针,再通过重载运算符实现迭代器像指针一样访问的行为。
• 这里的难点是operator++和operator--的实现。之前使用部分,我们分析了,map和set的迭代器走的是中序遍历,左子树->根结点->右子树,那么begin()会返回中序第一个结点的iterator也就是最左节点的迭代器。
• 迭代器++的核心逻辑就是不看全局,只看局部,只考虑当前中序局部要访问的下一个结点。迭代器++时,如果it指向的结点的右子树不为空,代表当前结点已经访问完了,要访问下一个结点是右子树的中序第一个,一棵树中序第一个是最左结点,所以直接找右子树的最左结点即可;如果it指向的结点的右子树空,代表当前结点已经访问完了且当前结点所在的子树也访问完了,要访问的下一个结点在当前结点的祖先里面,所以要沿着当前结点到根的祖先路径向上找。如果当前结点是父亲的左,根据中序左子树->根结点->右子树,那么下一个访问的结点就是当前结点的父亲;如果当前结点是父亲的右,根据中序左子树->根结点->右子树,当前结点所在的子树访问完了,当前结点所在父亲的子树也访问完了,那么下一个访问的需要继续往根的祖先中去找,直到找到孩子是父亲左的那个祖先,这个祖先就是中序要遍历的下一个结点;如果一直找到根都没有这样的祖先,说明当前结点已经是中序最后一个结点,下一个位置就是end()。(带哨兵位头结点时要特别注意:如果根结点本身就是最右结点,向上回溯会经过头结点,需要额外判断,保证停在头结点即end()上,而不是又回到根结点。)
• end()如何表示呢?stl源码中,红黑树增加了一个哨兵位头结点做为end(),这个哨兵位头结点和根互为父亲,左指向最左结点,右指向最右结点(空树时左右都指向头结点自己,这样begin()==end())。
• 迭代器--的实现跟++的思路完全类似,逻辑正好反过来即可,因为它的访问顺序是右子树->根结点->左子树。另外--end()需要特殊处理:直接走到头结点的右,也就是最右结点。
• set的iterator也不支持修改,我们把set的第二个模板参数改成const K即可, RBTree<K, const K, SetKeyOfT> _t;
• map的iterator不支持修改key但是可以修改value,我们把map的第二个模板参数pair的第一个参数改成const K即可, RBTree<K, pair<const K, V>, MapKeyOfT> _t;
• 至于reverse_iterator的实现和iterator类似,就不多赘述,具体实现看下面的代码。
通过上面的叙述我们了解到整体需要从结点的定义开始改造,即定义哨兵结点。
2.1.2 改造红黑树的代码
#pragma once
#include<iostream>
#include<assert.h>
using namespace std;
// Colour of a red-black tree node.
enum Colour { RED, BLACK };
// Red-black tree node.
// T is the element stored in the node: the key for set, a key/value pair for
// map. The tree's separate K parameter is the key type used for lookups.
template<class T>
struct RBTreeNode
{
    T _t;                            // stored element
    RBTreeNode* _left = nullptr;
    RBTreeNode* _right = nullptr;
    RBTreeNode* _parent = nullptr;
    Colour _col;

    // Links start out null; the colour defaults to red.
    RBTreeNode(const T& t = T(), Colour col = RED)
        : _t(t)
        , _col(col)
    {}
};
//红黑树的迭代器
template<class T, class Ref, class Ptr>
struct RBTreeIterator
{
typedef RBTreeNode<T> Node;
Node* _node;//要操作的节点
RBTreeIterator(Node* node) :_node(node) {}
typedef RBTreeIterator Self;
//++it
Self& operator++()
{
Node* cur = _node;
if (cur->_right)
{
Node* LeftMost = cur->_right;
while (LeftMost->_left)
{
LeftMost = LeftMost->_left;
}
_node = LeftMost;
}
else
{
Node* parent = cur->_parent;
while (parent->_right == cur)
{
cur = parent;
parent = cur->_parent;
}
_node = parent;
}
return *this;
}
//it++
Self operator++(int)
{
Self temp(*this);
++*this;
return temp;
}
//--it
Self& operator--()
{
Node* cur = _node;
//end()时
if (_node->_parent->_parent == _node && _node->_col == RED)
{
_node = _node->_right;
}
else if (cur->_left)
{
Node* RightMost = cur->_left;
while (RightMost->_right)
{
RightMost = RightMost->_right;
}
_node = RightMost;
}
else
{
Node* parent = cur->_parent;
while (parent->_left == cur)
{
cur = parent;
parent = cur->_parent;
}
_node = parent;
}
return *this;
}
//it--
Self operator--(int)
{
Self temp(*this);
--*this;
return temp;
}
Ref operator*()
{
return _node->_t;
}
Ptr operator->()
{
return &_node->_t;
}
bool operator!=(const Self& it)
{
return _node != it._node;
}
bool operator==(const Self& it)
{
return _node == it._node;
}
};
//红黑树的反向迭代器
template<class T, class Ref, class Ptr>
struct RBTreeReverseIterator
{
typedef RBTreeNode<T> Node;
Node* _node;//要操作的节点
RBTreeReverseIterator(Node* node) :_node(node) {}
typedef RBTreeReverseIterator Self;
//++it
Self& operator--()
{
Node* cur = _node;
if (cur->_right)
{
Node* LeftMost = cur->_right;
while (LeftMost->_left)
{
LeftMost = LeftMost->_left;
}
_node = LeftMost;
}
else
{
Node* parent = cur->_parent;
while (parent->_right == cur)
{
cur = parent;
parent = cur->_parent;
}
_node = parent;
}
return *this;
}
//it++
Self operator++(int)
{
Self temp(*this);
++*this;
return temp;
}
//--it
Self& operator++()
{
Node* cur = _node;
//end()时
if (_node->_parent->_parent == _node && _node->_col == RED)
{
_node = _node->_right;
}
else if (cur->_left)
{
Node* RightMost = cur->_left;
while (RightMost->_right)
{
RightMost = RightMost->_right;
}
_node = RightMost;
}
else
{
Node* parent = cur->_parent;
while (parent->_left == cur)
{
cur = parent;
parent = cur->_parent;
}
_node = parent;
}
return *this;
}
//it--
Self operator--(int)
{
Self temp(*this);
--*this;
return temp;
}
Ref operator*()
{
return _node->_t;
}
Ptr operator->()
{
return &_node->_t;
}
bool operator!=(const Self& it)
{
return _node != it._node;
}
bool operator==(const Self& it)
{
return _node == it._node;
}
};
//红黑树
template<class K, class T, class KeyOfT>
class RBTree
{
typedef RBTreeNode< T> Node;
void destory(Node* root)
{
if (root == nullptr)
return;
//必须后续销毁节点
destory(root->_left);
root->_left = nullptr;
destory(root->_right);
root->_right = nullptr;
delete root;
root = nullptr;
}
public:
RBTree()
{
_header = new Node(T());
}
//拷贝构造
RBTree(const RBTree& rbt)
{
_header = new Node(T());
for (const auto e : rbt)
{
insert(e);
}
}
RBTree(initializer_list<T> il)
{
_header = new Node(T());
for (const auto& e : il)
insert(e);
}
//赋值重载
RBTree& operator=(const RBTree& bst)
{
//现代写法
RBTree temp(bst);
std::swap(_header, temp._header);
return *this;
}
// 析构
~RBTree()
{
destory(_header->_parent);
delete _header;
_header = nullptr;
}
typedef RBTreeIterator<T, T&, T*> iterator;
typedef RBTreeIterator<T, const T&, const T*> const_iterator;
typedef RBTreeReverseIterator<T, T&, T*> reverse_iterator;
typedef RBTreeReverseIterator<T, const T&, const T*> const_reverse_iterator;
//迭代器
iterator begin()
{
return iterator(_header->_left);
}
const_iterator begin() const
{
return const_iterator(_header->_left);
}
iterator end()
{
return iterator(_header);
}
const_iterator end() const
{
return const_iterator(_header);
}
//反向迭代器
reverse_iterator rbegin()
{
return reverse_iterator(_header->_right);
}
const_reverse_iterator rbegin()const
{
return const_reverse_iterator(_header->_right);
}
reverse_iterator rend()
{
return reverse_iterator(_header);
}
const_reverse_iterator rend()const
{
return const_reverse_iterator(_header);
}
//此时需要注意key是set中的key,是map中pair<key,value>中的key,
// 因此我们需要取出想要的key
iterator find(const K& key) const
{
Node* cur = _header->_parent;//指向根节点
while (cur)
{
if (key < KeyOfT()(cur->_t))
{
cur = cur->_left;
}
else if (key > KeyOfT()(cur->_t))
{
cur = cur->_right;
}
else
{
return iterator(cur);
}
}
return iterator(_header);
}
pair<iterator, bool> insert(const T& t)
{
if (_header->_parent == nullptr)
{
_header->_parent = new Node(t);
_header->_parent->_col = BLACK;
_header->_parent->_parent = _header;
return { iterator(_header->_parent),true };
}
//查找结点
Node* parent = _header;
Node* cur = _header->_parent;
while (cur)
{
if (KeyOfT()(t) < KeyOfT()(cur->_t))
{
parent = cur;
cur = cur->_left;
}
else if (KeyOfT()(t) > KeyOfT()(cur->_t))
{
parent = cur;
cur = cur->_right;
}
else
{
return { iterator(cur), false };
}
}
//插入节点
cur = new Node(t);
cur->_parent = parent;
if (KeyOfT()(t) < KeyOfT()(parent->_t))
parent->_left = cur;
else
parent->_right = cur;
//保存cur,留待返回
Node* ret = cur;
//调整颜色
//情况一:parent存在且为黑,不需要调整
//情况二:parent存在且为红,需要调整
while (parent != _header && parent->_col == RED)
{
//确定 g 和 u 节点
Node* grandfather = parent->_parent;
Node* uncle = nullptr;
if (parent == grandfather->_left)
uncle = grandfather->_right;
else
uncle = grandfather->_left;
//在情况二下又有如下情况:
//情况1:unclude存在且为红
if (uncle && uncle->_col == RED)
{
parent->_col = uncle->_col = BLACK;
grandfather->_col = RED;
//继续向上调整
cur = grandfather;
parent = cur->_parent;
}
//情况2:uncle不存在或uncle为黑
else if (!uncle || uncle->_col == BLACK)
{
//parent 是 grandfather 的左孩子
if (parent == grandfather->_left)
{
//cur 是 parent 的左孩子
if (cur == parent->_left)
{
//旋转加变色
RotateRight(grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
}
//cur 是 parent 的右孩子
else
{
//旋转加变色
RotateLeftThenRight(grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
}
//parent 是 grandfather 的右孩子
else
{
//cur 是 parent 的左孩子
if (cur == parent->_left)
{
//旋转加变色
RotateRightThenLeft(grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
//cur 是 parent 的右孩子
else
{
//旋转加变色
RotateLeft(grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
}
}
break;//调整好能退出了
}
else
{
assert(false);//逻辑错误,理论上不可能走到
}
}
//保证根节点必须是黑色的
_header->_parent->_col = BLACK;
_header->_left = LeftMost();
_header->_right = RightMost();
return { iterator(ret), true };
}
bool empty()
{
return _header->_parent == nullptr;
}
size_t size()
{
size_t count = 0;
auto it = begin();
while (it != end())
{
count++;
++it;
}
return count;
}
void clear()
{
destory(_header->_parent);
_header->_parent = nullptr;
}
private:
Node* LeftMost()
{
Node* MostLeftChild = _header->_parent;
while (MostLeftChild->_left)
{
MostLeftChild = MostLeftChild->_left;
}
return MostLeftChild;
}
Node* RightMost()
{
Node* MostRightChild = _header->_parent;
while (MostRightChild->_right)
{
MostRightChild = MostRightChild->_right;
}
return MostRightChild;
}
void RotateLeft(Node* parent)
{
//保存节点,后面链接
Node* parentparent = parent->_parent;
Node* sub = parent->_right;
Node* subleft = sub->_left;
//重新链接
parent->_right = subleft;
if (subleft)
subleft->_parent = parent;
sub->_left = parent;
parent->_parent = sub;
//和上面的节点链接
if (parentparent == _header)
{
//说明parent原来是根节点
_header->_parent = sub;
sub->_parent = _header;
}
else
{
if (parentparent->_left == parent)
{
//原来是上面节点的左子树
parentparent->_left = sub;
}
else
{
parentparent->_right = sub;
}
sub->_parent = parentparent;
}
}
void RotateRight(Node* parent)
{
//保存节点,后面链接
Node* parentparent = parent->_parent;
Node* sub = parent->_left;
Node* subright = sub->_right;
//重新链接
parent->_left = subright;
if (subright)
subright->_parent = parent;
sub->_right = parent;
parent->_parent = sub;
//和上面的节点链接
if (parentparent == _header)
{
//说明parent原来是根节点
_header->_parent = sub;
sub->_parent = _header;
}
else
{
if (parentparent->_left == parent)
{
//原来是上面节点的左子树
parentparent->_left = sub;
}
else
{
parentparent->_right = sub;
}
sub->_parent = parentparent;
}
}
void RotateLeftThenRight(Node* parent)
{
RotateLeft(parent->_left);
RotateRight(parent);
}
void RotateRightThenLeft(Node* parent)
{
RotateRight(parent->_right);
RotateLeft(parent);
}
private:
Node* _header;
};
2.2 复用红黑树实现 set
#pragma once
#include"RBTree.h"
// Ordered set of unique keys built on RBTree.
// Elements are stored as const K so that iterators cannot modify a key in
// place, which would break the tree's ordering.
template<class K>
class set
{
    typedef const K T; // element type stored in the tree

    // Key extractor handed to RBTree: for a set, the element is its own key.
    struct KeyOfT
    {
        const K& operator()(const T& key)
        {
            return key;
        }
    };

public:
    // Iterator types are public so callers can name them (set<int>::iterator).
    typedef typename RBTree<K, T, KeyOfT>::iterator iterator;
    typedef typename RBTree<K, T, KeyOfT>::const_iterator const_iterator;
    typedef typename RBTree<K, T, KeyOfT>::reverse_iterator reverse_iterator;
    typedef typename RBTree<K, T, KeyOfT>::const_reverse_iterator const_reverse_iterator;

    set() = default;

    // Builds the tree directly from the list, without a temporary tree and
    // copy-assignment.
    set(initializer_list<T> il)
        : _rbt(il)
    {}

    // Forward iterators (ascending key order).
    iterator begin()
    {
        return _rbt.begin();
    }

    const_iterator begin() const
    {
        return _rbt.begin();
    }

    iterator end()
    {
        return _rbt.end();
    }

    const_iterator end() const
    {
        return _rbt.end();
    }

    // Reverse iterators (descending key order).
    reverse_iterator rbegin()
    {
        return _rbt.rbegin();
    }

    const_reverse_iterator rbegin() const
    {
        return _rbt.rbegin();
    }

    reverse_iterator rend()
    {
        return _rbt.rend();
    }

    const_reverse_iterator rend() const
    {
        return _rbt.rend();
    }

    bool empty()
    {
        return _rbt.empty();
    }

    size_t size()
    {
        return _rbt.size();
    }

    void clear()
    {
        _rbt.clear();
    }

    // Iterator to key, or end() if key is absent.
    iterator find(const K& key) const
    {
        return _rbt.find(key);
    }

    // Inserts t unless it is already present; returns {position, inserted?}.
    pair<iterator, bool> insert(const T& t)
    {
        return _rbt.insert(t);
    }

private:
    RBTree<K, T, KeyOfT> _rbt;
};
2.3 复用红黑树实现 map
#pragma once
#include"RBTree.h"
template<class K, class V>
class map
{
typedef pair<const K, V> T;
struct KeyOfT
{
const K& operator()(const T& data)
{
return data.first;
}
};
typedef typename RBTree<K, T, KeyOfT>::iterator iterator;
typedef typename RBTree<K, T, KeyOfT>::const_iterator const_iterator;
typedef typename RBTree<K, T, KeyOfT>::reverse_iterator reverse_iterator;
typedef typename RBTree<K, T, KeyOfT>::const_reverse_iterator const_reverse_iterator;
public:
map() = default;
map(initializer_list<T> il)
{
_rbt = { il };
}
//迭代器
iterator begin()
{
return _rbt.begin();
}
const_iterator begin()const
{
return _rbt.begin();
}
iterator end()
{
return _rbt.end();
}
const_iterator end() const
{
return _rbt.end();
}
//反向迭代器
reverse_iterator rbegin()
{
return _rbt.rbegin();
}
const_reverse_iterator rbegin()const
{
return _rbt.rbegin();
}
reverse_iterator rend()
{
return _rbt.rend();
}
const_reverse_iterator rend() const
{
return _rbt.rend();
}
bool empty()
{
return _rbt.empty();
}
size_t size()
{
return _rbt.size();
}
void clear()
{
_rbt.clear();
}
iterator find(const K& key)const
{
return _rbt.find(key);
}
pair<iterator, bool> insert(const pair<const K, V>& kv)
{
return _rbt.insert(kv);
}
V& operator[](const K& key)
{
pair<iterator, bool> ret = insert({ key, V() });
iterator it = ret.first;
return it->second;
}
private:
RBTree < K, pair<const K, V>, KeyOfT> _rbt;//key 不能修改
};