Skip to content
Snippets Groups Projects
Commit b3c2a724 authored by Martin Larralde's avatar Martin Larralde
Browse files

Avoid spawning a pool if `Mapper` methods are called with `threads=1`

parent 37031d74
No related branches found
No related tags found
No related merge requests found
......@@ -75,9 +75,10 @@ from _atomic_vector cimport atomic_vector
# --- Python imports ---------------------------------------------------------
import array
import warnings
import threading
import os
import multiprocessing.pool
import threading
import warnings
# --- Constants --------------------------------------------------------------
......@@ -286,6 +287,21 @@ cdef int _add_minimizers_prot(
# --- Cython classes ---------------------------------------------------------
class _DummyPool:
"""A dummy `~ThreadPool` that runs everything in the main thread.
"""
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
return False
def map(self, func, iterable):
for x in iterable:
func(x)
cdef class _Map:
"""A private class wrapping a heap-allocated compute map.
"""
......@@ -988,6 +1004,8 @@ cdef class Mapper(_Parameterized):
cdef void* data
cdef ssize_t slen
cdef object mem
# parallel computation
cdef object pool
# core genomic identity results
cdef vector[CGI_Results] results
cdef CGI_Results result
......@@ -996,13 +1014,24 @@ cdef class Mapper(_Parameterized):
cdef uint64_t shared_length
cdef list hits = []
# run everything in main thread if `threads==1`, otherwise run a
# `ThreadPool` with the required number of threads, if specified.
if threads == 0:
threads = os.cpu_count() or 1
if threads == 1:
pool = _DummyPool()
elif threads > 1:
pool = multiprocessing.pool.ThreadPool(threads)
else:
raise ValueError(f"`threads` must be positive or null, got {threads!r}")
# create a new mapper and a new l2 mapping vector
map = _Map.__new__(_Map)
map._map = new Map_t(self._param, self._sk[0], total_fragments, 0)
final_mappings = _FinalMappings.__new__(_FinalMappings)
# spawn a thread pool to map fragments in parallel for all the contigs
with multiprocessing.pool.ThreadPool(threads or None) as pool:
with pool:
for contig in contigs:
# check length of contig is enough for computing mapping
slen = len(contig)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment