From 803bf91694aa3249eeb0944052075d2a41328ba9 Mon Sep 17 00:00:00 2001 From: mban Date: Mon, 13 May 2024 19:21:47 +0900 Subject: [PATCH] =?UTF-8?q?segment=20tree=20=E9=85=8D=E5=88=97=E5=88=9D?= =?UTF-8?q?=E6=9C=9F=E9=95=B7=E3=81=95=E3=82=92=E3=82=B3=E3=83=B3=E3=82=B9?= =?UTF-8?q?=E3=83=88=E3=83=A9=E3=82=AF=E3=82=BF=E3=81=A7=E6=B1=BA=E3=82=81?= =?UTF-8?q?=E3=82=8B=E3=82=88=E3=81=86=E3=81=AB=E3=81=99=E3=82=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/kotlin/collections/SegmentTree.kt | 122 ++++++++++++++---- .../kotlin/collections/SegmentTreeTest.kt | 5 +- 2 files changed, 101 insertions(+), 26 deletions(-) diff --git a/src/main/kotlin/collections/SegmentTree.kt b/src/main/kotlin/collections/SegmentTree.kt index 416e05f..a46a936 100644 --- a/src/main/kotlin/collections/SegmentTree.kt +++ b/src/main/kotlin/collections/SegmentTree.kt @@ -1,38 +1,112 @@ package collections -// 演算、(左)単位元 -class SegmentTree(private val op: (T, T) -> T, private val id: T) { - companion object { - // 制約に合うように適宜変えて下さい - const val size = 1 shl 21 - } +class SegmentTree(val n: Int, private val op: (T, T) -> T, private val id: T) { - @Suppress("UNCHECKED_CAST") - private val array = Array(size * 2) { id } as Array + private val array: MutableList + private var size: Int = 1 - operator fun get(i: Int) = array[i + size] - operator fun set(i: Int, n: T) = update(i, n) - - fun update(i: Int, n: T) { - var index = i + size - array[index] = n - while (index > 1) { - index /= 2 - array[index] = op(array[index * 2], array[index * 2 + 1]) + init { + while (size < n) { + size *= 2 } + array = MutableList(2 * size) { id } } - private fun query(left: Int, right: Int, k: Int, l: Int, r: Int): T { - if (left <= l && r <= right) { - return array[k] + constructor(a: Array, op: (T, T) -> T, id: T) : this(a.size, op, id) { + for (i in 0 until n) { + array[i + size] = a[i] } + for (i in size - 1 downTo 1) { + array[i] = op(array[2 * i], array[2 * i + 1]) + } + } + + operator fun set(i: Int, item: T) { + assert(i in 0 until n) + var ptr = i + size + array[ptr] = item + while (ptr > 1) { + ptr /= 2 + array[ptr] = op(array[2 * ptr], array[2 * ptr + 1]) + } + } - if (r <= left || right <= l) { - return id + operator fun get(i: Int) = array[i + size] + + /** + * op(a[l],a[l+1],...a[r-1])を求める + */ + fun query(l: Int, r: Int): T { + assert(l in 0 until n) + assert(r in l..n) + var sml = id + var smr = id + var left = l + size + var right = r + size + while (left < right) { + if ((left and 1) != 0) sml = op(sml, array[left++]) + if ((right and 1) != 0) smr = op(array[--right], smr) + left = left shr 1 + right = right shr 1 } + return op(sml, smr) + } + - return op(query(left, right, k * 2, l, (l + r) / 2), query(left, right, k * 2 + 1, (l + r) / 2, r)) + fun all(): T = array[1] + + /** + * f(op(a[l],a[l+1],...a[r-1])) = trueとなる最大のrを返す + */ + fun maxRight(l: Int, f: (T) -> Boolean): Int { + assert(l in 0..n) + assert(f(id)) + if (l == n) return n + var left = l + size + var sm = id + + do { + while (left % 2 == 0) left = left shr 1 + if (!f(op(sm, array[left]))) { + while (left < size) { + left = left shl 1 + if (f(op(sm, array[left]))) { + sm = op(sm, array[left]) + left++ + } + } + return left - size + } + sm = op(sm, array[left]) + left++ + } while ((left and -left) != left) + return n } - fun query(left: Int, right: Int) = query(left, right, 1, 0, size) + /** + * f(op(a\[l\] ,a[l+1],...a[r-1])) = trueとなる最小のlを返します + */ + fun minLeft(r: Int, f: (T) -> Boolean): Int { + assert(r in 0..n) + assert(f(id)) + if (r == 0) return 0 + var right = r + size + var sm = id + do { + right-- + while (right > 1 && (right % 2 != 0)) right = right shr 1 + if (!f(op(array[right], sm))) { + while (right < size) { + right = (2 * right + 1) + if (f(op(array[right], sm))) { + sm = op(array[right], sm) + right-- + } + } + return right + 1 - size + } + sm = op(array[right], sm) + } while ((right and -right) != right) + return 0 + } } \ No newline at end of file diff --git a/src/test/kotlin/collections/SegmentTreeTest.kt b/src/test/kotlin/collections/SegmentTreeTest.kt index 355dda3..bdd1990 100644 --- a/src/test/kotlin/collections/SegmentTreeTest.kt +++ b/src/test/kotlin/collections/SegmentTreeTest.kt @@ -7,7 +7,7 @@ import kotlin.test.* class SegmentTreeTest { @Test fun randomTest() { - val st = SegmentTree({ l, r -> l + r }, 0) + val st = SegmentTree(1000, { l, r -> l + r }, 0) val array = Array(1000) { 0 } val random = Random(0) for (i in 1..100000) { @@ -20,6 +20,7 @@ class SegmentTreeTest { array[index] = num st[index] = num } + 1 -> { // 区間演算 val left = random.nextInt(1000) @@ -39,7 +40,7 @@ class SegmentTreeTest { @Test fun test1() { - val st = SegmentTree({ l, r -> l + r }, 0) + val st = SegmentTree(5, { l, r -> l + r }, 0) st[0] = 4 st[1] = 3 st[2] = 2