Skip to content

Commit

Permalink
Merge pull request numba#9691 from sklam/fix/iss9678
Browse files Browse the repository at this point in the history
Fix numba#9678. parfor issue with build_map
  • Loading branch information
esc authored Aug 14, 2024
2 parents abc8469 + e3b6bdf commit 4fc11d4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
17 changes: 12 additions & 5 deletions numba/parfors/parfor_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,15 +930,22 @@ def find_setitems_block(setitems, itemsset, block, typemap):
# used in a call then consider it unanalyzable and so
# unavailable for hoisting.
rhs = inst.value
def add_to_itemset(item):
assert isinstance(item, ir.Var), rhs
if getattr(typemap[item.name], "mutable", False):
itemsset.add(item.name)

if isinstance(rhs, ir.Expr):
if rhs.op in ["build_tuple", "build_list", "build_set", "build_map"]:
if rhs.op in ["build_tuple", "build_list", "build_set"]:
for item in rhs.items:
if getattr(typemap[item.name], "mutable", False):
itemsset.add(item.name)
add_to_itemset(item)
elif rhs.op == "build_map":
for pair in rhs.items:
for item in pair:
add_to_itemset(item)
elif rhs.op == "call":
for item in list(rhs.args) + [x[1] for x in rhs.kws]:
if getattr(typemap[item.name], "mutable", False):
itemsset.add(item.name)
add_to_itemset(item)

def find_setitems_body(setitems, itemsset, loop_body, typemap):
"""
Expand Down
17 changes: 17 additions & 0 deletions numba/tests/test_parfors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3327,6 +3327,23 @@ def foo():

self.assertEqual(foo(), foo.py_func())

def test_issue_9678_build_map(self):
def issue_9678(num_nodes):
out = 0
for inode_uint in numba.prange(num_nodes):
inode = numba.int64(inode_uint)
p = {inode: 0.0} # mainly this build_map bytecode here
for _ in range(5):
p[inode] += 1 # and here
out += p[inode]
return out

num_nodes = 12
issue_9678_serial = numba.jit(parallel=False)(issue_9678)
issue_9678_parallel = numba.jit(parallel=True)(issue_9678)
self.assertEqual(issue_9678_serial(num_nodes),
issue_9678_parallel(num_nodes))


@skip_parfors_unsupported
class TestParforsDiagnostics(TestParforsBase):
Expand Down

0 comments on commit 4fc11d4

Please sign in to comment.