From b3c2a7246adb6800e1eb7e8ead16ad4259e99130 Mon Sep 17 00:00:00 2001 From: Martin Larralde <martin.larralde@embl.de> Date: Thu, 4 Aug 2022 00:25:05 +0200 Subject: [PATCH] Avoid spawning a pool if `Mapper` methods are called with `threads=1` --- pyfastani/_fastani.pyx | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/pyfastani/_fastani.pyx b/pyfastani/_fastani.pyx index 45b2022..168a46d 100644 --- a/pyfastani/_fastani.pyx +++ b/pyfastani/_fastani.pyx @@ -75,9 +75,10 @@ from _atomic_vector cimport atomic_vector # --- Python imports --------------------------------------------------------- import array -import warnings -import threading +import os import multiprocessing.pool +import threading +import warnings # --- Constants -------------------------------------------------------------- @@ -286,6 +287,21 @@ cdef int _add_minimizers_prot( # --- Cython classes --------------------------------------------------------- +class _DummyPool: + """A dummy `~ThreadPool` that runs everything in the main thread. + """ + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + return False + + def map(self, func, iterable): + for x in iterable: + func(x) + + cdef class _Map: """A private class wrapping a heap-allocated compute map. """ @@ -988,6 +1004,8 @@ cdef class Mapper(_Parameterized): cdef void* data cdef ssize_t slen cdef object mem + # parallel computation + cdef object pool # core genomic identity results cdef vector[CGI_Results] results cdef CGI_Results result @@ -996,13 +1014,24 @@ cdef class Mapper(_Parameterized): cdef uint64_t shared_length cdef list hits = [] + # run everything in main thread if `threads==1`, otherwise run a + # `ThreadPool` with the required number of threads, if specified. + if threads == 0: + threads = os.cpu_count() or 1 + if threads == 1: + pool = _DummyPool() + elif threads > 1: + pool = multiprocessing.pool.ThreadPool(threads) + else: + raise ValueError(f"`threads` must be positive or null, got {threads!r}") + # create a new mapper and a new l2 mapping vector map = _Map.__new__(_Map) map._map = new Map_t(self._param, self._sk[0], total_fragments, 0) final_mappings = _FinalMappings.__new__(_FinalMappings) # spawn a thread pool to map fragments in parallel for all the contigs - with multiprocessing.pool.ThreadPool(threads or None) as pool: + with pool: for contig in contigs: # check length of contig is enough for computing mapping slen = len(contig) -- GitLab