diff --git a/README.md b/README.md index 6ab51882..5065c3e0 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ Gradle is used for development. - [Monotonic Queue](src/main/java/dataStructures/queue/monotonicQueue) - Segment Tree - [Stack](src/main/java/dataStructures/stack) +- [Segment Tree](src/main/java/dataStructures/segmentTree) - [Trie](src/main/java/dataStructures/trie) ## Algorithms @@ -86,6 +87,7 @@ Gradle is used for development. * [AVL-tree](src/main/java/dataStructures/avlTree) * [Trie](src/main/java/dataStructures/trie) * [B-Tree](src/main/java/dataStructures/bTree) + * [Segment Tree](src/main/java/dataStructures/segmentTree) (Not covered in CS2040s but useful!) * Red-Black Tree (Not covered in CS2040s but useful!) * [Orthogonal Range Searching](src/main/java/algorithms/orthogonalRangeSearching) * Interval Trees (**WIP**) diff --git a/docs/assets/images/SegmentTree.png b/docs/assets/images/SegmentTree.png new file mode 100644 index 00000000..44517df1 Binary files /dev/null and b/docs/assets/images/SegmentTree.png differ diff --git a/src/main/java/dataStructures/heap/README.md b/src/main/java/dataStructures/heap/README.md index 7ac53e1d..1f09cd76 100644 --- a/src/main/java/dataStructures/heap/README.md +++ b/src/main/java/dataStructures/heap/README.md @@ -27,6 +27,20 @@ That said, in practice, the array-based implementation of a heap often provides former, in cache efficiency and memory locality. This is due to its contiguous memory layout. As such, the implementation shown here is a 0-indexed array-based heap. +#### Obtain index representing child nodes +Suppose the parent node is captured at index *i* of the array (1-indexed). +**1-indexed**:
+Left Child: *i* x 2
+Right Child: *i* x 2 + 1
+ +The 1-indexed calculation is intuitive. So, when dealing with 0-indexed representation (as in our implementation), +one option is to convert 0-indexed to 1-indexed representation, do the above calculations, and revert.
+(Note: Now, we assume parent node is captured at index *i* (0-indexed)) + +**0-indexed**:
+Left Child: (*i* + 1) x 2 - 1 = *i* x 2 + 1
+Right Child: (*i* + 1) x 2 + 1 - 1 = *i* x 2 + 2
+ ### Relevance of increaseKey and decreaseKey operations The decision not to include explicit "decrease key" and "increase key" operations in the standard implementations of diff --git a/src/main/java/dataStructures/segmentTree/README.md b/src/main/java/dataStructures/segmentTree/README.md new file mode 100644 index 00000000..e531df78 --- /dev/null +++ b/src/main/java/dataStructures/segmentTree/README.md @@ -0,0 +1,89 @@ +# Segment Tree + +## Background +Segment Trees are primarily used to solve problems that require answers to queries on intervals of an array +with the possibility of modifying the array elements. +These queries could be finding the sum, minimum, or maximum in a subarray, or similar aggregated results. + +![Segment Tree](../../../../../docs/assets/images/SegmentTree.png) + +### Structure +(Note: See below for a brief description of the array-based implementation of a segment tree) + +A Segment Tree for an array of size *n* is a binary tree that stores information about segments of the array. +Each node in the tree represents an interval of the array, with the root representing the entire array. +The structure satisfies the following properties: +1. Leaf Nodes: Each leaf node represents a single element of the array. +2. Internal Nodes: Each internal node represents the sum of the values of its children +(which captures the segment of the array). Summing up, this node captures the whole segment. +3. Height: The height of the Segment Tree is O(log *n*), making queries and updates efficient. + +## Complexity Analysis +**Time**: O(log(n)) in general for query and update operations, +except construction which takes O(nlogn) + +**Space**: O(n), note for an array-based implementation, the array created should have size 4n (explained later) + +where n is the number of elements in the array. + +## Operations +### Construction +The construction of a Segment Tree starts with the root node representing the entire array and +recursively dividing the array into two halves until each segment is reduced to a single element. +This process is a divide-and-conquer strategy: +1. Base Case: If the current segment of the array is reduced to a single element, create a leaf node. +2. Recursive Case: Otherwise, split the array segment into two halves, construct the left and right children, +and then merge their results to build the parent node. + +This takes O(nlogn). logn in depth, and will visit each leaf node (number of leaf nodes could be roughly 2n) once. + +### Querying +To query an interval, say to find the sum of elements in the interval (L, R), +the tree is traversed starting from the root: +1. If the current node's segment is completely within (L, R), its value is part of the answer. +2. If the current node's segment is completely outside (L, R), it is ignored. +3. If the current node's segment partially overlaps with (L, R), the query is recursively applied to its children. + +This approach ensures that each level of the tree is visited only once, time complexity of O(logn). + +### Updating +Updating an element involves changing the value of a leaf node and then propagating this change up to the root +to ensure the tree reflects the updated array. +This is done by traversing the path from the leaf node to the root +and updating each node along this path (update parent to the sum of its children). + +This can be done in O(logn). + +## Array-based Segment Tree +The array-based implementation of a Segment Tree is an efficient way to represent the tree in memory, especially +since a Segment Tree is a complete binary tree. +This method utilizes a simple array where each element of the array corresponds to a node in the tree, +including both leaves and internal nodes. + +### Why 4n space +The size of the array needed to represent a Segment Tree for an array of size *n* is 2*2^ceil(log2(*n*)) - 1. +We do 2^(ceil(log2(*n*))) because *n* might not be a perfect power of 2, +**so we expand the array size to the next power of 2**. +This adjustment ensures that each level of the tree is fully filled except possibly for the last level, +which is filled from left to right. + +**BUT**, 2^(ceil(log2(*n*))) seems overly-complex. To ensure we have sufficient space, we can just consider 2*n +because 2*n >= 2^(ceil(log2(*n*))). +Now, these 2n nodes can be thought of as the 'leaf' nodes (or more precisely, an upper-bound). To account for the +intermediate nodes, we use the property that for a complete binary that is fully filled, the number of leaf nodes += number of intermediate nodes (recall: sum i -> 0 to n-1 of 2^i = 2^n). So we create an array of size 2n * 2 = 4n to +guarantee we can house the entire segment tree. + +### Obtain index representing child nodes +Suppose the parent node is captured at index *i* of the array (1-indexed). +**1-indexed**:
+Left Child: *i* x 2
+Right Child: *i* x 2 + 1
+ +The 1-indexed calculation is intuitive. So, when dealing with 0-indexed representation (as in our implementation), +one option is to convert 0-indexed to 1-indexed representation, do the above calculations, and revert.
+(Note: Now, we assume parent node is captured at index *i* (0-indexed)) + +**0-indexed**:
+Left Child: (*i* + 1) x 2 - 1 = *i* x 2 + 1
+Right Child: (*i* + 1) x 2 + 1 - 1 = *i* x 2 + 2
diff --git a/src/main/java/dataStructures/segmentTree/SegmentTree.java b/src/main/java/dataStructures/segmentTree/SegmentTree.java new file mode 100644 index 00000000..f860d093 --- /dev/null +++ b/src/main/java/dataStructures/segmentTree/SegmentTree.java @@ -0,0 +1,115 @@ +package dataStructures.segmentTree; + +/** + * Implementation of a Segment Tree. Uses SegmentTreeNode as a helper node class. + */ +public class SegmentTree { + private SegmentTreeNode root; + private int[] array; + + /** + * Helper node class. Used internally. + */ + private class SegmentTreeNode { + private SegmentTreeNode leftChild; // left child + private SegmentTreeNode rightChild; // right child + private int start; // start idx of range captured + private int end; // end idx of range captured + private int sum; // sum of all elements between start and end index inclusive + + /** + * Constructor + * @param leftChild + * @param rightChild + * @param start + * @param end + * @param sum + */ + public SegmentTreeNode(SegmentTreeNode leftChild, SegmentTreeNode rightChild, int start, int end, int sum) { + this.leftChild = leftChild; + this.rightChild = rightChild; + this.start = start; + this.end = end; + this.sum = sum; + } + } + + /** + * Constructor. + * @param nums + */ + public SegmentTree(int[] nums) { + root = buildTree(nums, 0, nums.length - 1); + array = nums; + } + + private SegmentTreeNode buildTree(int[] nums, int start, int end) { + if (start == end) { + return new SegmentTreeNode(null, null, start, end, nums[start]); + } + int mid = start + (end - start) / 2; + SegmentTreeNode left = buildTree(nums, start, mid); + SegmentTreeNode right = buildTree(nums, mid + 1, end); + return new SegmentTreeNode(left, right, start, end, left.sum + right.sum); + } + + /** + * Queries the sum of all values in the specified range. + * @param leftEnd + * @param rightEnd + * @return the sum. + */ + public int query(int leftEnd, int rightEnd) { + return query(root, leftEnd, rightEnd); + } + + private int query(SegmentTreeNode node, int leftEnd, int rightEnd) { + // this is the case when: + // start end + // range query: ^ ^ --> so simply capture the sum at this node! + if (leftEnd <= node.start && node.end <= rightEnd) { + return node.sum; + } + int rangeSum = 0; + int mid = node.start + (node.end - node.start) / 2; + // Consider the 3 possible kinds of range queries + // start mid end + // poss 1: ^ ^ + // poss 2: ^ ^ + // poss 3: ^ ^ + if (leftEnd <= mid) { + rangeSum += query(node.leftChild, leftEnd, Math.min(rightEnd, mid)); // poss1 or poss2 + } + if (mid + 1 <= rightEnd) { + rangeSum += query(node.rightChild, Math.max(leftEnd, mid + 1), rightEnd); // poss2 or poss3 + } + return rangeSum; + } + + /** + * Updates the segment tree based on updates to the array at the specified index with the specified value. + * @param idx + * @param val + */ + public void update(int idx, int val) { + if (idx > array.length) { + return; + } + array[idx] = val; + update(root, idx, val); + } + + private void update(SegmentTreeNode node, int idx, int val) { + if (node.start == node.end && node.start == idx) { + node.sum = val; // node is holding a single value; now updated + return; + } + int mid = node.start + (node.end - node.start) / 2; + if (idx <= mid) { + update(node.leftChild, idx, val); + } else { + update(node.rightChild, idx, val); + } + node.sum = node.leftChild.sum + node.rightChild.sum; // propagate updates up + } +} diff --git a/src/main/java/dataStructures/segmentTree/arrayRepresentation/SegmentTree.java b/src/main/java/dataStructures/segmentTree/arrayRepresentation/SegmentTree.java new file mode 100644 index 00000000..a75f656a --- /dev/null +++ b/src/main/java/dataStructures/segmentTree/arrayRepresentation/SegmentTree.java @@ -0,0 +1,105 @@ +package dataStructures.segmentTree.arrayRepresentation; + +/** + * Array-based implementation of a Segment Tree. + */ +public class SegmentTree { + private int[] tree; + private int[] array; + + /** + * Constructor. + * @param nums + */ + public SegmentTree(int[] nums) { + tree = new int[4 * nums.length]; // Need to account for up to 4n nodes. + array = nums; + buildTree(nums, 0, nums.length - 1, 0); + } + + /** + * Builds the tree from the given array of numbers. + * Unlikely before where we capture child nodes in the helper node class, here we capture position of child nodes + * in the array-representation of the tree with an additional variable. + * @param nums + * @param start + * @param end + * @param idx tells us which index of the tree array we are at. + */ + private void buildTree(int[] nums, int start, int end, int idx) { + // recall, each node is a position in the array + // explicitly track which position in the array to fill with idx variable + if (start == end) { + tree[idx] = nums[start]; + return; + } + int mid = start + (end - start) / 2; + int idxLeftChild = (idx + 1) * 2 - 1; // convert from 0-based to 1-based, do computation, then revert + buildTree(nums, start, mid, idxLeftChild); + int idxRightChild = (idx + 1) * 2 + 1 - 1; // convert from 0-based to 1-based, do computation, then revert + buildTree(nums, mid + 1, end, idxRightChild); + tree[idx] = tree[idxLeftChild] + tree[idxRightChild]; + } + + /** + * Queries the sum of all values in the specified range. + * @param leftEnd + * @param rightEnd + * @return the sum. + */ + public int query(int leftEnd, int rightEnd) { + return query(0, 0, array.length - 1, leftEnd, rightEnd); + } + + private int query(int nodeIdx, int startRange, int endRange, int leftEnd, int rightEnd) { + // this is the case when: + // start end + // range query: ^ ^ --> so simply capture the sum at this node! + if (leftEnd <= startRange && endRange <= rightEnd) { + return tree[nodeIdx]; + } + int rangeSum = 0; + int mid = startRange + (endRange - startRange) / 2; + // Consider the 3 possible kinds of range queries + // start mid end + // poss 1: ^ ^ + // poss 2: ^ ^ + // poss 3: ^ ^ + if (leftEnd <= mid) { + int idxLeftChild = (nodeIdx + 1) * 2 - 1; + rangeSum += query(idxLeftChild, startRange, mid, leftEnd, Math.min(rightEnd, mid)); + } + if (mid + 1 <= rightEnd) { + int idxRightChild = (nodeIdx + 1) * 2 + 1 - 1; + rangeSum += query(idxRightChild, mid + 1, endRange, Math.max(leftEnd, mid + 1), rightEnd); + } + return rangeSum; + } + + /** + * Updates the segment tree based on updates to the array at the specified index with the specified value. + * @param idx + * @param val + */ + public void update(int idx, int val) { + if (idx > array.length) { + return; + } + array[idx] = val; + update(0, 0, array.length - 1, idx, val); + } + + private void update(int nodeIdx, int startRange, int endRange, int idx, int val) { + if (startRange == endRange) { + tree[nodeIdx] = val; + return; + } + int mid = startRange + (endRange - startRange) / 2; + if (idx <= mid) { + update(nodeIdx * 2 + 1, startRange, mid, idx, val); + } else { + update(nodeIdx * 2 + 2, mid + 1, endRange, idx, val); + } + tree[nodeIdx] = tree[nodeIdx * 2 + 1] + tree[nodeIdx * 2 + 2]; + } +} diff --git a/src/test/java/dataStructures/segmentTree/SegmentTreeTest.java b/src/test/java/dataStructures/segmentTree/SegmentTreeTest.java new file mode 100644 index 00000000..d187c56b --- /dev/null +++ b/src/test/java/dataStructures/segmentTree/SegmentTreeTest.java @@ -0,0 +1,39 @@ +package dataStructures.segmentTree; +import static org.junit.Assert.assertEquals; + +import org.junit.Test; + +public class SegmentTreeTest { + @Test + public void construct_shouldConstructSegmentTree() { + int[] arr1 = new int[] {7, 77, 37, 67, 33, 73, 13, 2, 7, 17, 87, 53}; + SegmentTree tree1 = new SegmentTree(arr1); + assertEquals(arr1[1] + arr1[2] + arr1[3], tree1.query(1, 3)); + assertEquals(arr1[4] + arr1[5] + arr1[6] + arr1[7], tree1.query(4, 7)); + int sum1 = 0; + for (int i = 0; i < arr1.length; i++) { + sum1 += arr1[i]; + } + assertEquals(sum1, tree1.query(0, arr1.length - 1)); + + + int[] arr2 = new int[] {7, -77, 37, 67, -33, 0, 73, -13, 2, -7, 17, 0, -87, 53, 0}; // some negatives and 0s + SegmentTree tree2 = new SegmentTree(arr1); + assertEquals(arr1[1] + arr1[2] + arr1[3], tree2.query(1, 3)); + assertEquals(arr1[4] + arr1[5] + arr1[6] + arr1[7], tree2.query(4, 7)); + int sum2 = 0; + for (int i = 0; i < arr1.length; i++) { + sum2 += arr1[i]; + } + assertEquals(sum2, tree2.query(0, arr1.length - 1)); + } + + @Test + public void update_shouldUpdateSegmentTree() { + int[] arr = new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + SegmentTree tree = new SegmentTree(arr); + assertEquals(55, tree.query(0, 10)); + tree.update(5, 55); + assertEquals(105, tree.query(0, 10)); + } +} diff --git a/src/test/java/dataStructures/segmentTree/arrayRepresentation/SegmentTreeTest.java b/src/test/java/dataStructures/segmentTree/arrayRepresentation/SegmentTreeTest.java new file mode 100644 index 00000000..132201b8 --- /dev/null +++ b/src/test/java/dataStructures/segmentTree/arrayRepresentation/SegmentTreeTest.java @@ -0,0 +1,42 @@ +package dataStructures.segmentTree.arrayRepresentation; +import static org.junit.Assert.assertEquals; + +import org.junit.Test; + +/** + * This file is essentially duplicated from the parent. + */ +public class SegmentTreeTest { + @Test + public void construct_shouldConstructSegmentTree() { + int[] arr1 = new int[] {7, 77, 37, 67, 33, 73, 13, 2, 7, 17, 87, 53}; + SegmentTree tree1 = new SegmentTree(arr1); + assertEquals(arr1[1] + arr1[2] + arr1[3], tree1.query(1, 3)); + assertEquals(arr1[4] + arr1[5] + arr1[6] + arr1[7], tree1.query(4, 7)); + int sum1 = 0; + for (int i = 0; i < arr1.length; i++) { + sum1 += arr1[i]; + } + assertEquals(sum1, tree1.query(0, arr1.length - 1)); + + + int[] arr2 = new int[] {7, -77, 37, 67, -33, 0, 73, -13, 2, -7, 17, 0, -87, 53, 0}; // some negatives and 0s + SegmentTree tree2 = new SegmentTree(arr1); + assertEquals(arr1[1] + arr1[2] + arr1[3], tree2.query(1, 3)); + assertEquals(arr1[4] + arr1[5] + arr1[6] + arr1[7], tree2.query(4, 7)); + int sum2 = 0; + for (int i = 0; i < arr1.length; i++) { + sum2 += arr1[i]; + } + assertEquals(sum2, tree2.query(0, arr1.length - 1)); + } + + @Test + public void update_shouldUpdateSegmentTree() { + int[] arr = new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + SegmentTree tree = new SegmentTree(arr); + assertEquals(55, tree.query(0, 10)); + tree.update(5, 55); + assertEquals(105, tree.query(0, 10)); + } +}