Segment Trees - Introduction

·

2 min read

#include <bits/stdc++.h>
using namespace std;

// number of elements in the input array; read from standard input in main()
int n;
// input values; main() sizes this to n+1 and fills indices 1..n, so the
// array is used 1-indexed and slot 0 is never touched by build()
vector<int> a;
// segment tree of range sums stored as a flat array: the root is node 1 and
// node k has its children at 2*k and 2*k+1; main() sizes it to 4*n, zero-filled
vector<int> segTree;

// build function to build the segment tree

void build(int i, int j, int k) {

    // Fills the subtree rooted at segTree[k] with the sums of a[i..j].
    //
    // i -> first index (inclusive) of the array slice this node covers
    // j -> last index (inclusive) of the array slice this node covers
    // k -> index of this node in segTree; its children are 2*k and 2*k+1
    //
    // Reads the global array a and writes the global array segTree; the
    // caller must have sized segTree large enough for every node index used.

    // time complexity analysis
    // every node is visited exactly once, and the number of nodes per level
    // doubles going down the tree, i.e. 1 + 2 + 4 + ... + n;
    // read from the last level upwards that is n + n/2 + ... + 1, which is
    // about 2n, so the whole build is O(n)

    // an empty slice has nothing to build; without this guard the midpoint
    // split below never reaches i == j and the recursion does not terminate
    // (e.g. build(1, 0, 1) when the input array is empty)
    if(i > j)
        return;

    // leaf: the node covers exactly one array element
    if(i == j) {
        segTree[k] = a[i];
        return;
    }

    // midpoint written as i+(j-i)/2 so that i+j is never computed
    const int mid = i + (j - i) / 2;

    build(i, mid, 2*k);
    build(mid + 1, j, 2*k+1);

    // an internal node stores the sum of its two children
    segTree[k] = segTree[2*k] + segTree[2*k+1];

}

// update function to update a given point in the segment tree

void update(int i, int j, int k, int pos, int x) {

    // Sets a[pos] to x and refreshes every segment tree node on the path
    // from node k down to the leaf for pos.
    //
    // i -> first index (inclusive) of the range covered by node k
    // j -> last index (inclusive) of the range covered by node k
    // k -> current segment tree index; its children are 2*k and 2*k+1
    // pos -> index where the change is made
    // x -> new value
    //
    // If pos is not inside [i, j] the call does nothing.

    // time complexity analysis
    // at each node only the half containing pos is visited, so the search
    // space is halved at every step: O(logn)

    // pos lies outside this node's range, so nothing under this node can
    // change; this also catches an empty range (i > j), which would
    // otherwise recurse without ever reaching a leaf
    if(pos < i || pos > j)
        return;

    // leaf: the guard above guarantees i == j == pos here
    if(i == j) {
        a[pos] = x;
        segTree[k] = x;
        return;
    }

    // midpoint written as i+(j-i)/2 so that i+j is never computed
    const int mid = i + (j - i) / 2;

    // descend into the single child whose range contains pos
    if(pos <= mid) {
        update(i, mid, 2*k, pos, x);
    } else {
        update(mid + 1, j, 2*k+1, pos, x);
    }

    // one child changed, so recompute this node's sum
    segTree[k] = segTree[2*k] + segTree[2*k+1];

}

// query function to return sum in the range l to r

int query(int i, int j, int k, int l, int r) {

    // Returns the sum of the array values in [l, r] that fall inside the
    // range [i, j] covered by segment tree node k.
    //
    // i -> first index (inclusive) of the range covered by node k
    // j -> last index (inclusive) of the range covered by node k
    // k -> current segment tree index; its children are 2*k and 2*k+1
    // l -> first index (inclusive) of the query
    // r -> last index (inclusive) of the query

    // time complexity: O(logn)

    // the node's range lies completely inside the query: its stored sum
    // is used as-is
    const bool nodeInsideQuery = (l <= i) && (j <= r);
    if(nodeInsideQuery) {
        return segTree[k];
    }

    // the node's range and the query do not touch: contributes nothing
    const bool queryTouchesNode = !(j < l) && !(r < i);
    if(!queryTouchesNode) {
        return 0;
    }

    // partial overlap: split at the midpoint and combine both halves
    const int mid = i + (j - i) / 2;

    const int lowerHalfSum = query(i, mid, 2*k, l, r);
    const int upperHalfSum = query(mid + 1, j, 2*k+1, l, r);

    return lowerHalfSum + upperHalfSum;

}

int main() {

    // Reads the array size n followed by n integers from standard input
    // and builds the sum segment tree over them.

    // a failed read leaves n at 0 and a negative n makes no sense as a
    // size; both would hand build() an empty range, so reject them here
    if(!(cin >> n) || n < 0)
        return 1;

    // an empty array is valid input, there is simply no tree to build
    if(n == 0)
        return 0;

    // sizes are computed in size_t so that 4*n cannot overflow int;
    // a is 1-indexed (slot 0 unused) and 4*n slots are enough for every
    // node index the recursive build reaches when the root is node 1
    const size_t count = static_cast<size_t>(n);

    a.resize(count + 1);
    segTree.resize(4 * count, 0);

    for(int i=1; i<=n; i++)
        cin >> a[i];

    build(1, n, 1);

    return 0;
}