线段树:高效处理区间查询与更新的数据结构

什么是线段树?

线段树是一种二叉树数据结构,用于存储区间或线段的信息。它能够在O(log n)时间复杂度内完成区间查询和单点/区间更新操作,如区间和、区间最小值/最大值、区间乘积等。

基本特性

  • 完全二叉树:线段树通常实现为完全二叉树
  • 存储区间信息:每个节点存储一个区间的某种聚合信息
  • 高效操作:查询和更新操作的时间复杂度均为O(log n)
  • 空间复杂度:O(n),需要2n-1个节点存储n个元素的数组

线段树的结构

线段树的每个节点代表一个区间[l, r]:

  • 叶子节点:代表单个元素的区间[i, i]
  • 内部节点:代表合并的子区间,通常为左右子节点的区间合并
示例:数组[1, 3, 5, 7, 9, 11]的线段树(区间和)

            [0,5](36)
           /        \
      [0,2](9)      [3,5](27)
      /     \        /     \
 [0,1](4) [2,2](5) [3,4](16) [5,5](11)
  /   \             /   \
[0,0](1) [1,1](3) [3,3](7) [4,4](9)

线段树的构建

构建线段树是一个自底向上的过程:

  1. 从叶子节点开始,每个叶子节点对应数组中的一个元素
  2. 递归地构建父节点,合并子节点的信息
def build(node, l, r, arr, tree):
    if l == r:
        tree[node] = arr[l]
        return
    mid = (l + r) // 2
    build(2*node+1, l, mid, arr, tree)
    build(2*node+2, mid+1, r, arr, tree)
    tree[node] = tree[2*node+1] + tree[2*node+2]  # 以区间和为例

线段树的查询

查询区间[q_l, q_r]的信息:

  1. 如果当前节点区间完全包含在查询区间内,直接返回节点值
  2. 如果与查询区间无交集,返回不影响结果的值(如求和的0,最小值的∞)
  3. 否则递归查询左右子树并合并结果
def query(node, l, r, q_l, q_r, tree):
    if r < q_l or l > q_r:  # 无交集
        return 0
    if q_l <= l and r <= q_r:  # 完全包含
        return tree[node]
    mid = (l + r) // 2
    left = query(2*node+1, l, mid, q_l, q_r, tree)
    right = query(2*node+2, mid+1, r, q_l, q_r, tree)
    return left + right

线段树的更新

单点更新

  1. 找到对应的叶子节点
  2. 更新叶子节点的值
  3. 递归向上更新所有受影响的父节点
def update(node, l, r, idx, val, tree):
    if l == r:
        tree[node] = val
        return
    mid = (l + r) // 2
    if idx <= mid:
        update(2*node+1, l, mid, idx, val, tree)
    else:
        update(2*node+2, mid+1, r, idx, val, tree)
    tree[node] = tree[2*node+1] + tree[2*node+2]

区间更新(延迟传播)

对于区间更新操作,可以使用延迟传播(Lazy Propagation)技术来优化:

  1. 标记需要更新但尚未实际执行的区间
  2. 在查询或进一步更新时传播这些标记

线段树的应用

线段树可以解决多种区间操作问题:

  1. 区间求和:查询任意区间的元素和
  2. 区间最值:查询区间最小值/最大值
  3. 区间统计:统计满足特定条件的元素数量
  4. 区间覆盖:批量修改区间内的元素值
  5. 逆序对计数:统计数组中的逆序对数量
  6. 扫描线算法:用于计算几何中的矩形面积并等问题

线段树的变种

  1. zkw线段树:非递归实现,效率更高
  2. 动态开点线段树:节省空间,适用于稀疏数据
  3. 二维线段树:处理二维平面上的区间查询
  4. 持久化线段树:支持历史版本查询

代码示例(Python实现区间和线段树)

class SegmentTree:
    def __init__(self, data):
        self.n = len(data)
        self.size = 1
        while self.size < self.n:
            self.size <<= 1
        self.tree = [0] * (2 * self.size)
        # 初始化叶子节点
        for i in range(self.n):
            self.tree[self.size + i] = data[i]
        # 构建内部节点
        for i in range(self.size - 1, 0, -1):
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def update(self, index, value):
        pos = self.size + index
        self.tree[pos] = value
        pos >>= 1
        while pos >= 1:
            new_val = self.tree[2 * pos] + self.tree[2 * pos + 1]
            if self.tree[pos] == new_val:
                break
            self.tree[pos] = new_val
            pos >>= 1

    def query(self, l, r):
        res = 0
        l += self.size
        r += self.size
        while l <= r:
            if l % 2 == 1:
                res += self.tree[l]
                l += 1
            if r % 2 == 0:
                res += self.tree[r]
                r -= 1
            l >>= 1
            r >>= 1
        return res

代码示例(Golang实现区间和线段树)

package main

import (
	"fmt"
)

type SegmentTree struct {
	data []int
	tree []int
	size int
}

// NewSegmentTree 创建线段树
func NewSegmentTree(data []int) *SegmentTree {
	n := len(data)
	st := &SegmentTree{
		data: data,
		size: 1,
	}
	
	// 计算线段树大小(最接近且大于等于n的2的幂次)
	for st.size < n {
		st.size <<= 1
	}
	
	// 初始化线段树
	st.tree = make([]int, 2*st.size)
	
	// 填充叶子节点
	for i := 0; i < n; i++ {
		st.tree[st.size+i] = data[i]
	}
	
	// 构建内部节点
	for i := st.size - 1; i > 0; i-- {
		st.tree[i] = st.tree[2*i] + st.tree[2*i+1]
	}
	
	return st
}

// Update 更新指定位置的值
func (st *SegmentTree) Update(index int, value int) {
	pos := st.size + index
	st.tree[pos] = value
	
	// 向上更新父节点
	for pos > 1 {
		pos >>= 1
		newVal := st.tree[2*pos] + st.tree[2*pos+1]
		if st.tree[pos] == newVal {
			break // 如果没有变化,可以提前终止
		}
		st.tree[pos] = newVal
	}
}

// Query 查询区间[l, r]的和
func (st *SegmentTree) Query(l, r int) int {
	res := 0
	l += st.size
	r += st.size
	
	for l <= r {
		// 如果l是右子节点,单独处理
		if l%2 == 1 {
			res += st.tree[l]
			l++
		}
		
		// 如果r是左子节点,单独处理
		if r%2 == 0 {
			res += st.tree[r]
			r--
		}
		
		// 移动到父节点
		l >>= 1
		r >>= 1
	}
	
	return res
}

func main() {
	data := []int{1, 3, 5, 7, 9, 11}
	st := NewSegmentTree(data)
	
	fmt.Println("初始线段树:")
	fmt.Println(st.tree)
	
	fmt.Println("\n查询区间[1,4]的和:", st.Query(1, 4))
	
	fmt.Println("\n更新索引2的值为10")
	st.Update(2, 10)
	fmt.Println("更新后的线段树:", st.tree)
	
	fmt.Println("\n查询区间[0,5]的和:", st.Query(0, 5))
	fmt.Println("查询区间[2,3]的和:", st.Query(2, 3))
}

总结

线段树是一种功能强大且灵活的数据结构,特别适合处理各种区间查询和更新问题。虽然实现起来比简单的前缀数组或差分数组更复杂,但在需要频繁混合查询和更新的场景下,线段树的性能优势非常明显。

滚动至顶部