diff --git a/src/iterators/bfs.jl b/src/iterators/bfs.jl index 57779243..d45f7793 100644 --- a/src/iterators/bfs.jl +++ b/src/iterators/bfs.jl @@ -35,15 +35,16 @@ end BFSVertexIteratorState `BFSVertexIteratorState` is a struct to hold the current state of iteration -in BFS which is needed for the `Base.iterate()` function. A queue is used to -keep track of the vertices which will be visited during BFS. Since the queue -can contains repetitions of already visited nodes, we also keep track of that -in a `BitVector` so that to skip those nodes. +in BFS which is needed for the `Base.iterate()` function. We use two vectors, +one for the current level nodes and one from the next level nodes to visit +the graph. Since new levels can contains repetitions of already visited nodes, +we also keep track of that in a `BitVector` so as to skip those nodes. """ mutable struct BFSVertexIteratorState visited::BitVector - queue::Vector{Int} - neighbor_idx::Int + curr_level::Vector{Int} + next_level::Vector{Int} + node_idx::Int n_visited::Int end @@ -58,14 +59,17 @@ First iteration to visit vertices in a graph using breadth-first search. function Base.iterate(t::BFSIterator{<:Integer}) visited = falses(nv(t.graph)) visited[t.source] = true - return (t.source, BFSVertexIteratorState(visited, [t.source], 1, 1)) + state = BFSVertexIteratorState(visited, [t.source], Int[], 0, 0) + return Base.iterate(t, state) end function Base.iterate(t::BFSIterator{<:AbstractArray}) visited = falses(nv(t.graph)) - visited[first(t.source)] = true - state = BFSVertexIteratorState(visited, copy(t.source), 1, 1) - return (first(t.source), state) + curr_level = unique(s for s in t.source) + sort!(curr_level) + visited[curr_level] .= true + state = BFSVertexIteratorState(visited, curr_level, Int[], 0, 0) + return Base.iterate(t, state) end """ @@ -74,36 +78,25 @@ end Iterator to visit vertices in a graph using breadth-first search. """ function Base.iterate(t::BFSIterator, state::BFSVertexIteratorState) - graph, visited, queue = t.graph, state.visited, state.queue - while !isempty(queue) - if state.n_visited == nv(graph) - return nothing - end - # we visit the first node in the queue - node_start = first(queue) - if !visited[node_start] - visited[node_start] = true - state.n_visited += 1 - return (node_start, state) - end - # which means we arrive here when the first node was visited. - neigh = outneighbors(graph, node_start) - if state.neighbor_idx <= length(neigh) - node = neigh[state.neighbor_idx] - # we update the idx of the neighbor we will visit, - # if it is already visited, we repeat - state.neighbor_idx += 1 - if !visited[node] - push!(queue, node) - state.visited[node] = true - state.n_visited += 1 - return (node, state) + # we fill nodes in this level + if state.node_idx == length(state.curr_level) + state.n_visited += length(state.curr_level) + state.n_visited == nv(t.graph) && return nothing + @inbounds for node in state.curr_level + for adj_node in outneighbors(t.graph, node) + if !state.visited[adj_node] + push!(state.next_level, adj_node) + state.visited[adj_node] = true + end end - else - # when the first node and its neighbors are visited - # we remove the first node of the queue - popfirst!(queue) - state.neighbor_idx = 1 end + length(state.next_level) == 0 && return nothing + state.curr_level, state.next_level = state.next_level, empty!(state.curr_level) + sort!(state.curr_level) + state.node_idx = 0 end + # we visit all nodes in this level + state.node_idx += 1 + @inbounds node = state.curr_level[state.node_idx] + return (node, state) end diff --git a/test/iterators/bfs.jl b/test/iterators/bfs.jl index c70bdf36..f459eafa 100644 --- a/test/iterators/bfs.jl +++ b/test/iterators/bfs.jl @@ -35,7 +35,14 @@ end end nodes_visited = collect(BFSIterator(g2, [1, 6])) - @test nodes_visited == [1, 2, 3, 6, 5, 7, 4] + levels = ([1, 6], [2, 3, 5, 7], [4]) + @test sort(nodes_visited[1:2]) == sort(levels[1]) + @test sort(nodes_visited[3:6]) == sort(levels[2]) + @test sort(nodes_visited[7:end]) == sort(levels[3]) + nodes_visited = collect(BFSIterator(g2, [8, 1, 6])) - @test nodes_visited == [8, 9, 1, 2, 3, 6, 5, 7, 4] + levels = ([8, 1, 6], [2, 3, 5, 7, 9], [4]) + @test sort(nodes_visited[1:3]) == sort(levels[1]) + @test sort(nodes_visited[4:8]) == sort(levels[2]) + @test sort(nodes_visited[9:end]) == sort(levels[3]) end