树堆(Treap)

2024/12/27

由于最近在力扣上刷题,碰到了这个题目:1847. 最近的房间

这个题目里面需要一个结构来支持边插入边排序,并且还要求较高的查询效率。由于我用的是 GoLang, 标准库里没有 TreeSet 之类的结构,所以不得不手搓一个。

虽然 GoLang 内置了 heap 的实现,只不过这个只能用于优先队列,肯定不符合我们的要求。

在这里我优先想到的不是直接用堆来实现,因为感觉有点大材小用了,所以我尝试用跳表来解决,因为这两者理论上的时间复杂度差不多。

结果。。。超时了。

1e5 + 1e4 的数据量, 1e4 的查询,需要耗费 18 秒左右。。。我打印了一下日志计算了一下,发现插入是性能瓶颈, 部分节点的插入耗时甚至可以超过 1 秒钟。

感兴趣的话可以看一下我的代码:

package q1847 import ( "math/rand" "slices" "strconv" "time" ) var maxLevel = 32 const factor = 0.25 const headNodeVal = -99999999 func randomLevel() int { lv := 1 for lv < maxLevel && rand.Float64() > factor { lv++ } return lv } type SkipListNode struct { val int forward []*SkipListNode } type SkipList struct { head *SkipListNode level int } func (s *SkipList) findMatch(base int) int { current := s.head for i := s.level - 1; i >= 0; i-- { fwd := current.forward[i] for fwd != nil && fwd.val < base { current = fwd fwd = current.forward[i] } } if current.forward[0] == nil { if current == s.head { return -1 } return current.val } fwd := current.forward[0] if base-current.val <= fwd.val-base { return current.val } return fwd.val } func (s *SkipList) add(val int) { lv := randomLevel() s.level = max(s.level, lv) node := &SkipListNode{ val: val, forward: make([]*SkipListNode, lv), } current := s.head for i := lv - 1; i >= 0; i-- { fwd := current.forward[i] for fwd != nil && fwd.val < val { current = fwd fwd = current.forward[i] } node.forward[i] = current.forward[i] current.forward[i] = node } } type Event struct { id, size, queryPos int } func closestRoom(rooms [][]int, queries [][]int) []int { events := make([]*Event, len(rooms)+len(queries)) pos := 0 result := make([]int, len(queries)) for i := range rooms { events[pos] = &Event{ id: rooms[i][0], size: rooms[i][1], queryPos: -1, } pos++ } for i := range queries { events[pos] = &Event{ id: queries[i][0], size: queries[i][1], queryPos: i, } pos++ } slices.SortStableFunc(events, func(a, b *Event) int { if a.size == b.size { return a.queryPos - b.queryPos } return b.size - a.size }) skipList := SkipList{ head: &SkipListNode{ val: headNodeVal, forward: make([]*SkipListNode, maxLevel), }, level: 0, } for _, evt := range events { if evt.queryPos == -1 { start := time.Now().Unix() skipList.add(evt.id) sp := time.Now().Unix() - start if sp > 0 { println("Insert: " + strconv.FormatInt(sp, 10)) } continue } start := time.Now().Unix() result[evt.queryPos] = skipList.findMatch(evt.id) sp := time.Now().Unix() - start if sp > 0 { println("Query: " + strconv.FormatInt(sp, 10)) } } return result }
go

去网上搜了一下,发现跳表相对于堆更节省内容,但是其它方面是不如堆的。对于我们敲算法的,只要内存只要用不爆,就往死里用。。。所以还是得整一个堆,但是对于我们刷力扣的,一般用不上那么高大尚的,所以简单好写的树堆就是我们的最佳选择了!

基础概念

堆,它的特性为堆上的任意一个节点的值,一定大于/小于其所有的子节点,如果是小于,则是小根堆,反之则是大跟堆。我们的优先队列就是使用堆来实现的。

想要实现一个堆也非常简单,我们利用数组的特性,开辟一个节点数 * 4大小的数组,对于任意一个索引位置 i, 它的相关节点分别为:

  • 左节点: i << 1i * 2
  • 右节点: i << 1 | 1i * 2 + 1
  • 父节点: i >> 1i / 2

写位运算逼格一下就拉上来了

我们在维护堆的时候,只需保存最后一个索引位就行了,相关操作大致为:

  • 插入:和父节点比较,不满足条件就和父节点交换值。
  • 删除:将最后一个节点的值替换到根节点,之后比较左右两个节点是否满足要求,不满足则交换值后继续向下递归。

树堆

堆讲完了,那么它和树堆有什么关系呢?

在树堆中,每个节点都会有一个权值,这个权值完全随机,在节点创建时确定,说白了就是在创建节点的时候 roll 一个权值出来保存到节点里面。

有了权值,堆就来干活了!我们在维护树堆的同时必须保证其满足堆的特性,除此之外,我们还需要维护它作为树的特性,即左小右大的特性

总之一个树堆包括了树和堆的特性:

  • 堆: 某个节点的权值一定大于所有子节点
  • 树: 左节点的值小于根节点,右节点的值大于根节点

在这里堆看起来有点多余,我们明明只需要树的特性,即边插入边查询的特性,为什么还需要堆呢?

其实堆这里是为了防止树退化成链表,具体可以看看这个: 树高的证明。反之我是看不懂,用就完了。。。

这里是声明的代码:

type TreapNode struct { value, count, weight int left, right *TreapNode } type Treap struct { root *TreapNode // 用于后面的查询 ans, pos int }
go

插入

在插入的时候,我们只需要正常插入即可,最后在创建权值的时候随机 roll 一个权值出来就可以了, 最后利用递归检查节点之前的权值,通过旋转来满足堆的特性。

一般来说, 树有四种旋转的类型: LL、RR、LR 和 RL。在树堆里面, 我们只需要用到 LL 和 RR 即可:

  • LL 旋: 左节点的权值大于当前节点。
  • RR 旋: 右节点的权值大于当前节点。
func (t *Treap) insert(node *TreapNode, val int) { if node.count == 0 { node.value = val node.count = 1 node.weight = int(rand.Float64() * 1e5) // 这里不要直接用 nil,用 count 来判断是否为空,减少心智负担 node.left = &TreapNode{} node.right = &TreapNode{} } else if node.value == val { node.count++ } else if val > node.value { t.insert(node.right, val) if node.right.weight > node.weight { node.rrRotate() } } else { t.insert(node.left, val) if node.left.weight > node.weight { node.llRotate() } } }
go

查询前驱

func (t *Treap) front(node *TreapNode, val int, currentMatch int) int { if node.count == 0 { return currentMatch } if val <= node.value { return t.front(node.left, val, currentMatch) } return t.front(node.right, val, node.value) }
go

查询后继

func (t *Treap) backend(node *TreapNode, val int, currentMatch int) int { if node.count == 0 { return currentMatch } if val >= node.value { return t.backend(node.right, val, currentMatch) } return t.backend(node.left, val, min(currentMatch, node.value)) }
go

拓展

前面的几个实现已经够我们完成上面的题目了,但是为了追求完美,我们可以继续实现其它特性。刷题可以看:P3369 【模板】普通平衡树

删除

删除也同样很简单,我们只需要把目标节点旋转到叶子结点就可以了,在删除时也应该保持堆的特性, 对于左右节点均非空的情况下,和堆一样,应该优先把权值较大的节点旋转上来

虽然这题不需要删除,但还是把代码写一下。

func (t *Treap) delete(node *TreapNode, val int) { if node.count == 0 { return } if node.value == val { if node.count > 1 { node.count-- } else { // assert node.count == 1 if node.left.isNil() && node.right.isNil() { *node = TreapNode{} } else if node.left.isNil() { node.rrRotate() t.delete(node.left, val) } else if node.right.isNil() { node.llRotate() t.delete(node.right, val) } else { if node.left.weight > node.right.weight { node.llRotate() t.delete(node.right, val) } else { node.rrRotate() t.delete(node.left, val) } } } } else if val > node.value { t.delete(node.right, val) } else { t.delete(node.left, val) } node.resize() }
go

如果你想偷懒,在删除时不维护堆的特性也可以,实际并不会影响多少效率,只不过在插入的时候可能需要多旋转几次罢了。

根据排名查找元素

需要注意一下题目没说没有找到的时候要返回 int 最大值,如果你 RE 了大概率是这里的问题。

func (t *Treap) getByRank(node *TreapNode, expected int) int { if node.count == 0 { return math.MaxInt32 } if node.left.size >= expected { return t.getByRank(node.left, expected) } //assert node.left.size < rank if node.left.size >= expected-node.count && node.left.size <= expected-1 { return node.value } return t.getByRank(node.right, expected-node.count-node.left.size) }
go

这里的判断条件有点绕,可以这样想:

假设现在 node.count = 3, expected = 5 并且当前的 node 满足要求。

  • 如果 3 个节点中的第一个匹配,那么左边最多得有 4 个节点,即 expected - 1
  • 如果 3 个节点中的最后一个匹配,那么左边至少得有 2 个节点,即 expected - node.count

最后构成一个区间,当左节点的数量满足区间要求时,当前节点就是我们的答案。再画个抽象点的图:

expected -> | | =================--------------- node.left.size node.count
text

= 表示左边的节点(值小于当前节点的节点), - 表示当前的根节点, | 表示期望的位置,当 | 指向 - 时,表示当前节点符合要求。 此时我们可以想象:不断左右拉伸 =,最后让 | 落到 - 上,形成的一个区间就是满足要求的区间。