diff --git a/numba/parfors/parfor_lowering.py b/numba/parfors/parfor_lowering.py index 7a67311f4cb..bbca774cebe 100644 --- a/numba/parfors/parfor_lowering.py +++ b/numba/parfors/parfor_lowering.py @@ -930,15 +930,22 @@ def find_setitems_block(setitems, itemsset, block, typemap): # used in a call then consider it unanalyzable and so # unavailable for hoisting. rhs = inst.value + def add_to_itemset(item): + assert isinstance(item, ir.Var), rhs + if getattr(typemap[item.name], "mutable", False): + itemsset.add(item.name) + if isinstance(rhs, ir.Expr): - if rhs.op in ["build_tuple", "build_list", "build_set", "build_map"]: + if rhs.op in ["build_tuple", "build_list", "build_set"]: for item in rhs.items: - if getattr(typemap[item.name], "mutable", False): - itemsset.add(item.name) + add_to_itemset(item) + elif rhs.op == "build_map": + for pair in rhs.items: + for item in pair: + add_to_itemset(item) elif rhs.op == "call": for item in list(rhs.args) + [x[1] for x in rhs.kws]: - if getattr(typemap[item.name], "mutable", False): - itemsset.add(item.name) + add_to_itemset(item) def find_setitems_body(setitems, itemsset, loop_body, typemap): """ diff --git a/numba/tests/test_parfors.py b/numba/tests/test_parfors.py index 556086438df..e333b05a587 100644 --- a/numba/tests/test_parfors.py +++ b/numba/tests/test_parfors.py @@ -3327,6 +3327,23 @@ def foo(): self.assertEqual(foo(), foo.py_func()) + def test_issue_9678_build_map(self): + def issue_9678(num_nodes): + out = 0 + for inode_uint in numba.prange(num_nodes): + inode = numba.int64(inode_uint) + p = {inode: 0.0} # mainly this build_map bytecode here + for _ in range(5): + p[inode] += 1 # and here + out += p[inode] + return out + + num_nodes = 12 + issue_9678_serial = numba.jit(parallel=False)(issue_9678) + issue_9678_parallel = numba.jit(parallel=True)(issue_9678) + self.assertEqual(issue_9678_serial(num_nodes), + issue_9678_parallel(num_nodes)) + @skip_parfors_unsupported class TestParforsDiagnostics(TestParforsBase):