Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bfs iterator for multiple source nodes #382

Merged
merged 16 commits into from
Jun 26, 2024
71 changes: 32 additions & 39 deletions src/iterators/bfs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,16 @@ end
BFSVertexIteratorState

`BFSVertexIteratorState` is a struct to hold the current state of iteration
in BFS which is needed for the `Base.iterate()` function. A queue is used to
keep track of the vertices which will be visited during BFS. Since the queue
can contains repetitions of already visited nodes, we also keep track of that
in a `BitVector` so that to skip those nodes.
in BFS which is needed for the `Base.iterate()` function. We use two vectors,
one for the current level nodes and one from the next level nodes to visit
the graph. Since new levels can contains repetitions of already visited nodes,
we also keep track of that in a `BitVector` so as to skip those nodes.
"""
mutable struct BFSVertexIteratorState
visited::BitVector
queue::Vector{Int}
neighbor_idx::Int
curr_level::Vector{Int}
next_level::Vector{Int}
node_idx::Int
n_visited::Int
end

Expand All @@ -58,14 +59,17 @@ First iteration to visit vertices in a graph using breadth-first search.
function Base.iterate(t::BFSIterator{<:Integer})
visited = falses(nv(t.graph))
visited[t.source] = true
return (t.source, BFSVertexIteratorState(visited, [t.source], 1, 1))
state = BFSVertexIteratorState(visited, [t.source], Int[], 0, 0)
return Base.iterate(t, state)
end

function Base.iterate(t::BFSIterator{<:AbstractArray})
visited = falses(nv(t.graph))
visited[first(t.source)] = true
state = BFSVertexIteratorState(visited, copy(t.source), 1, 1)
return (first(t.source), state)
curr_level = unique(s for s in t.source)
sort!(curr_level)
visited[curr_level] .= true
state = BFSVertexIteratorState(visited, curr_level, Int[], 0, 0)
return Base.iterate(t, state)
end

"""
Expand All @@ -74,36 +78,25 @@ end
Iterator to visit vertices in a graph using breadth-first search.
"""
function Base.iterate(t::BFSIterator, state::BFSVertexIteratorState)
graph, visited, queue = t.graph, state.visited, state.queue
while !isempty(queue)
if state.n_visited == nv(graph)
return nothing
end
# we visit the first node in the queue
node_start = first(queue)
if !visited[node_start]
visited[node_start] = true
state.n_visited += 1
return (node_start, state)
end
# which means we arrive here when the first node was visited.
neigh = outneighbors(graph, node_start)
if state.neighbor_idx <= length(neigh)
node = neigh[state.neighbor_idx]
# we update the idx of the neighbor we will visit,
# if it is already visited, we repeat
state.neighbor_idx += 1
if !visited[node]
push!(queue, node)
state.visited[node] = true
state.n_visited += 1
return (node, state)
# we fill nodes in this level
if state.node_idx == length(state.curr_level)
state.n_visited += length(state.curr_level)
state.n_visited == nv(t.graph) && return nothing
@inbounds for node in state.curr_level
Tortar marked this conversation as resolved.
Show resolved Hide resolved
for adj_node in outneighbors(t.graph, node)
if !state.visited[adj_node]
push!(state.next_level, adj_node)
state.visited[adj_node] = true
end
end
else
# when the first node and its neighbors are visited
# we remove the first node of the queue
popfirst!(queue)
state.neighbor_idx = 1
end
length(state.next_level) == 0 && return nothing
state.curr_level, state.next_level = state.next_level, empty!(state.curr_level)
sort!(state.curr_level)
state.node_idx = 0
end
# we visit all nodes in this level
state.node_idx += 1
@inbounds node = state.curr_level[state.node_idx]
Tortar marked this conversation as resolved.
Show resolved Hide resolved
return (node, state)
end
11 changes: 9 additions & 2 deletions test/iterators/bfs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,14 @@
end
end
nodes_visited = collect(BFSIterator(g2, [1, 6]))
@test nodes_visited == [1, 2, 3, 6, 5, 7, 4]
levels = ([1, 6], [2, 3, 5, 7], [4])
@test sort(nodes_visited[1:2]) == sort(levels[1])
@test sort(nodes_visited[3:6]) == sort(levels[2])
@test sort(nodes_visited[7:end]) == sort(levels[3])

nodes_visited = collect(BFSIterator(g2, [8, 1, 6]))
@test nodes_visited == [8, 9, 1, 2, 3, 6, 5, 7, 4]
levels = ([8, 1, 6], [2, 3, 5, 7, 9], [4])
@test sort(nodes_visited[1:3]) == sort(levels[1])
@test sort(nodes_visited[4:8]) == sort(levels[2])
@test sort(nodes_visited[9:end]) == sort(levels[3])
end
Loading