little star's memory

競プロ、なぞなぞ、その他

KotlinのBIT、セグメント木、遅延セグメント木

AtCoder Library (ACL)が公開されました。公式で対応しているのはC++のみですが、有志により他言語への翻訳が進められています。Kotlinへの翻訳プロジェクトもあります。(追記: 非公開になっていました)

自分もKotlinへの翻訳を試みましたが、データ構造・アルゴリズムの理解もプログラミングの理解も浅いため全く進みませんでした…。セグメント木だけはなんとか作ったので、ACL Contest 1の前に公開しておきます。使用するのは自由ですが、バグっていても責任は負いません。

BIT

ACLではFenwick Treeと呼ばれていますが、BITと呼ぶことにしました(短いので)。

class BIT(private val n: Int) {
    private val data = LongArray(n) { 0 }

    fun add(i: Int, x: Long) {
        var p = i + 1
        while (p <= n) {
            data[p - 1] += x
            p += p and -p
        }
    }

    fun sum(l: Int, r: Int): Long {
        return sum(r) - sum(l)
    }

    private fun sum(i: Int): Long {
        var r = i
        var s = 0L
        while(r > 0) {
            s += data[r - 1]
            r -= r and -r
        }
        return s
    }
}

ACLではint, uint, ll, ull, modintに対応していますが、上のコードはLongにしか対応していません。

使い方

  • add(p, x): a[p]にxを加えます。
  • sum(l,r): a[l]+...+a[r-1]を求めます。

使用例

セグメント木

class SegTree<Monoid>(private val n: Int, private val op: (Monoid, Monoid) -> Monoid, private val e: Monoid) : Iterable<Monoid> {
    private var d: MutableList<Monoid>
    private var size = 1
    private var log = 0

    init {
        while (size < n) {
            size *= 2
            log++
        }
        d = MutableList(2 * size) { e }
    }

    constructor(a: Array<Monoid>, op: (Monoid, Monoid) -> Monoid, e: Monoid) : this(a.count(), op, e) {
        for (i in 0 until n) {
            d[size + i] = a[i]
        }
        for (i in size - 1 downTo 1) update(i)
    }

    operator fun set(p: Int, x: Monoid) {
        if (p !in 0 until n) {
            throw Exception()
        }
        val p1 = p + size
        d[p1] = x
        for (i in 1..log) {
            update(p1 shr i)
        }
    }

    operator fun get(p: Int): Monoid {
        if (p !in 0 until n) {
            throw Exception()
        }
        return d[p + size]
    }

    fun prod(l: Int, r: Int): Monoid {
        if (l !in 0..r || r !in l..n) {
            throw Exception()
        }
        var sml = e
        var smr = e
        var l1 = l + size
        var r1 = r + size
        while (l1 < r1) {
            if (l1 and 1 != 0) sml = op(sml, d[l1++])
            if (r1 and 1 != 0) smr = op(d[--r1], smr)
            l1 /= 2
            r1 /= 2
        }
        return op(sml, smr)
    }

    fun allProd(): Monoid {
        return d[1]
    }

    fun maxRight(l: Int, f: (Monoid) -> Boolean): Int {
        if (l !in 0..n || !f(e)) {
            throw Exception()
        }
        if (l == n) return n
        var l1 = l + size
        var sm = e
        do {
            while (l1 % 2 == 0) l1 /= 2
            if (!f(op(sm, d[l1]))) {
                while (l1 < size) {
                    l1 *= 2
                    if (f(op(sm, d[l1]))) {
                        sm = op(sm, d[l1])
                        l1++
                    }
                }
                return l1 - size
            }
            sm = op(sm, d[l1])
            l1++
        } while ((l1 and -l1) != l1)
        return n
    }

    fun minLeft(r: Int, f: (Monoid) -> Boolean): Int {
        if (r !in 0..n || !f(e)) {
            throw Exception()
        }
        if (r == 0) return 0
        var r1 = r + size
        var sm = e
        do {
            r1--
            while (r1 > 1 && r1 % 2 != 0) r1 /= 2
            if (!f(op(d[r1], sm))) {
                while (r1 < size) {
                    r1 = 2 * r1 + 1
                    if (f(op(d[r1], sm))) {
                        sm = op(d[r1], sm)
                        r1--
                    }
                }
                return r1 + 1 - size
            }
            sm = op(d[r1], sm)
        } while ((r1 and -r1) != r1)
        return 0
    }

    private fun update(k: Int) {
        d[k] = op(d[2 * k], d[2 * k + 1])
    }

    override fun iterator(): Iterator<Monoid> {
        return (0 until n).map { this[it] }.iterator()
    }
}

使い方

  • コンストラクタ: 長さn, 演算子op, 単位元eを渡すと、初期値eの長さnの数列ができます。数列a, 演算子op, 単位元eを渡すと、初期値がaの要素になります。
  • set: a[p] = xと書くことができます。
  • get: a[p]と書くことで値を取得できます。
  • prod(l, r): op(a[l],...,a[r-1])を計算します。
  • allProd(): op(a[0],...,a[n-1])をO(1)で求めます。
  • maxRight(l, f): fが単調のとき、f(op(a[l], a[l + 1], ..., a[r - 1])) = trueとなるような最大のrを返します。
  • minLeft(r, f): fが単調のとき、f(op(a[l], a[l + 1], ..., a[r - 1])) = trueとなるような最小のlを返します。

備考

単位元の扱いがACLと異なっています。ACLではeはMonoidの値を返す引数なし関数ですが、上のコードではeはMonoidの元としました。(あまり違いがわかっていない)

ACLにないですが、イテレータを実装しました。joinToString()が使えるので、デバッグに便利かもしれません。

使用例

遅延セグメント木

class LazySegTree<Monoid, F>(private val n: Int, private val op: (Monoid, Monoid) -> Monoid, private val e: Monoid, private val mapping: (F, Monoid) -> Monoid, private val composition: (F, F) -> F, private val id: F) : Iterable<Monoid> {
    private var d: MutableList<Monoid>
    private var lz: MutableList<F>
    private var size = 1
    private var log = 0

    init {
        while (size < n) {
            size *= 2
            log++
        }
        d = MutableList(2 * size) { e }
        lz = MutableList(size) { id }
    }

    constructor(a: Array<Monoid>, op: (Monoid, Monoid) -> Monoid, e: Monoid, mapping: (F, Monoid) -> Monoid, composition: (F, F) -> F, id: F) : this(a.count(), op, e, mapping, composition, id) {
        for (i in 0 until n) {
            d[size + i] = a[i]
        }
        for (i in size - 1 downTo 1) {
            update(i)
        }
    }

    operator fun set(p: Int, x: Monoid) {
        if (p !in 0 until n) {
            throw Exception()
        }
        val p1 = p + size
        for (i in log downTo 1) {
            push(p1 shr i)
        }
        d[p1] = x
        for (i in 1..log) {
            update(p1 shr i)
        }
    }

    operator fun get(p: Int): Monoid {
        if (p !in 0 until n) {
            throw Exception()
        }
        val p1 = p + size
        for (i in log downTo 1) {
            push(p1 shr i)
        }
        return d[p1]
    }

    fun prod(l: Int, r: Int): Monoid {
        if (l !in 0..r || r !in l..n) {
            throw Exception()
        }
        if (l == r) return e
        var l1 = l + size
        var r1 = r + size
        for (i in log downTo 1) {
            if ((l1 shr i) shl i != l1) push(l1 shr i)
            if ((r1 shr i) shl i != r1) push(r1 shr i)
        }
        var sml = e
        var smr = e
        while (l1 < r1) {
            if (l1 and 1 != 0) sml = op(sml, d[l1++])
            if (r1 and 1 != 0) smr = op(d[--r1], smr)
            l1 /= 2
            r1 /= 2
        }
        return op(sml, smr)
    }

    fun allProd(): Monoid {
        return d[1]
    }

    fun apply(p: Int, f: F) {
        if (p !in 0 until n) {
            throw Exception()
        }
        val p1 = p + size
        for (i in log downTo 1) {
            push(p1 shr i)
        }
        d[p1] = mapping(f, d[p1])
        for (i in 1..log) {
            update(p1 shr i)
        }
    }

    fun apply(l: Int, r: Int, f: F) {
        if (l !in 0..r || r !in l..n) {
            throw Exception()
        }
        if (l == r) return
        var l1 = l + size
        var r1 = r + size
        for (i in log downTo 1) {
            if ((l1 shr i) shl i != l1) push(l1 shr i)
            if ((r1 shr i) shl i != r1) push((r1 - 1) shr i)
        }
        val l2 = l1
        val r2 = r1
        while (l1 < r1) {
            if (l1 and 1 != 0) allApply(l1++, f)
            if (r1 and 1 != 0) allApply(--r1, f)
            l1 /= 2
            r1 /= 2
        }
        l1 = l2
        r1 = r2
        for (i in 1..log) {
            if ((l1 shr i) shl i != l1) update(l1 shr i)
            if ((r1 shr i) shl i != r1) update((r1 - 1) shr i)
        }
    }

    fun maxRight(l: Int, g: (Monoid) -> Boolean): Int {
        if (l !in 0..n || !g(e)) {
            throw Exception()
        }
        if (l == n) return n
        var l1 = l + size
        for (i in log downTo 1) push(l1 shr i)
        var sm = e
        do {
            while (l1 % 2 == 0) l1 /= 2
            if (!g(op(sm, d[l1]))) {
                while (l1 < size) {
                    push(l1)
                    l1 *= 2
                    if (g(op(sm, d[l1]))) {
                        sm = op(sm, d[l1])
                        l1++
                    }
                }
                return l1 - size
            }
            sm = op(sm, d[l1])
            l1++
        } while ((l1 and -l1) != l1)
        return n
    }

    fun minLeft(r: Int, g: (Monoid) -> Boolean): Int {
        if (r !in 0..n || !g(e)) {
            throw Exception()
        }
        if (r == 0) return 0
        var r1 = r + size
        for (i in log downTo 1) push((r1 - 1) shr i)
        var sm = e
        do {
            r1--
            while (r1 > 1 && r1 % 2 != 0) r1 /= 2
            if (!g(op(d[r1], sm))) {
                while (r1 < size) {
                    push(r1)
                    r1 = 2 * r1 + 1
                    if (g(op(d[r1], sm))) {
                        sm = op(d[r1], sm)
                        r1--
                    }
                }
                return r1 + 1 - size
            }
            sm = op(d[r1], sm)
        } while ((r1 and -r1) != r1)
        return 0
    }

    private fun update(k: Int) {
        d[k] = op(d[2 * k], d[2 * k + 1])
    }

    private fun allApply (k: Int, f: F) {
        d[k] = mapping(f, d[k])
        if (k < size) {
            lz[k] = composition(f, lz[k])
        }
    }

    private fun push(k: Int) {
        allApply(2 * k, lz[k])
        allApply(2 * k + 1, lz[k])
        lz[k] = id
    }

    override fun iterator(): Iterator<Monoid> {
        return (0 until n).map { this[it] }.iterator()
    }
}

使い方

  • コンストラクタ: 長さn, 演算子op, 単位元e, $(f,x)\mapsto f(x)$を表すmapping, 写像の合成を表すcomposition, 恒等写像idを渡します。初期値はeになります。nの代わりに数列aを渡すこともできます。
  • set: a[p] = xと書くことができます。
  • get: a[p]と書くことで値を取得できます。
  • prod(l, r): op(a[l],...,a[r-1])を計算します。
  • allProd(): op(a[0],...,a[n-1])をO(1)で求めます。
  • apply(p, f): a[p]をf(a[p])に置き換えます。
  • apply(l, r, f): i=l,...,r-1に対し、a[i]をf(a[i])に置き換えます。
  • maxRight(l, f): fが単調のとき、f(op(a[l], a[l + 1], ..., a[r - 1])) = trueとなるような最大のrを返します。
  • minLeft(r, f): fが単調のとき、f(op(a[l], a[l + 1], ..., a[r - 1])) = trueとなるような最小のlを返します。

備考

セグメント木と同様にイテレータを実装していますが、とても重いと思うのでデバッグ以外で使うのはやめたほうがいいかもしれません。

使用例

AtCoder Library Practice Contest K - Range Affine Range SumはTLEしてしまうのでACできていません。