From b3c2a7246adb6800e1eb7e8ead16ad4259e99130 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Thu, 4 Aug 2022 00:25:05 +0200
Subject: [PATCH] Avoid spawning a pool if `Mapper` methods are called with
 `threads=1`

---
 pyfastani/_fastani.pyx | 35 ++++++++++++++++++++++++++++++++---
 1 file changed, 32 insertions(+), 3 deletions(-)

diff --git a/pyfastani/_fastani.pyx b/pyfastani/_fastani.pyx
index 45b2022..168a46d 100644
--- a/pyfastani/_fastani.pyx
+++ b/pyfastani/_fastani.pyx
@@ -75,9 +75,10 @@ from _atomic_vector cimport atomic_vector
 # --- Python imports ---------------------------------------------------------
 
 import array
-import warnings
-import threading
+import os
 import multiprocessing.pool
+import threading
+import warnings
 
 # --- Constants --------------------------------------------------------------
 
@@ -286,6 +287,21 @@ cdef int _add_minimizers_prot(
 # --- Cython classes ---------------------------------------------------------
 
 
+class _DummyPool:
+    """A dummy `~ThreadPool` that runs everything in the main thread.
+    """
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        return False
+
+    def map(self, func, iterable):
+        for x in iterable:
+            func(x)
+
+
 cdef class _Map:
     """A private class wrapping a heap-allocated compute map.
     """
@@ -988,6 +1004,8 @@ cdef class Mapper(_Parameterized):
         cdef void*                             data
         cdef ssize_t                           slen
         cdef object                            mem
+        # parallel computation
+        cdef object                            pool
         # core genomic identity results
         cdef vector[CGI_Results]               results
         cdef CGI_Results                       result
@@ -996,13 +1014,24 @@ cdef class Mapper(_Parameterized):
         cdef uint64_t                          shared_length
         cdef list                              hits            = []
 
+        # run everything in main thread if `threads==1`, otherwise run a
+        # `ThreadPool` with the required number of threads, if specified.
+        if threads == 0:
+            threads = os.cpu_count() or 1
+        if threads == 1:
+            pool = _DummyPool()
+        elif threads > 1:
+            pool = multiprocessing.pool.ThreadPool(threads)
+        else:
+            raise ValueError(f"`threads` must be positive or null, got {threads!r}")
+
         # create a new mapper and a new l2 mapping vector
         map = _Map.__new__(_Map)
         map._map = new Map_t(self._param, self._sk[0], total_fragments, 0)
         final_mappings = _FinalMappings.__new__(_FinalMappings)
 
         # spawn a thread pool to map fragments in parallel for all the contigs
-        with multiprocessing.pool.ThreadPool(threads or None) as pool:
+        with pool:
             for contig in contigs:
                 # check length of contig is enough for computing mapping
                 slen = len(contig)
-- 
GitLab