From 4b3abeac45562672d44a0678c9a2fd033787cc75 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Wed, 3 Aug 2022 04:43:56 +0200
Subject: [PATCH] Rewrite `Mapper` methods so that `Mapper._query_fragment` can
 be parallelized

---
 include/fastani/map/compute_map.pxd |  38 ++--
 pyfastani/_fastani.pyx              | 260 +++++++++++++++-------------
 2 files changed, 164 insertions(+), 134 deletions(-)

diff --git a/include/fastani/map/compute_map.pxd b/include/fastani/map/compute_map.pxd
index edf684b..fc34e97 100644
--- a/include/fastani/map/compute_map.pxd
+++ b/include/fastani/map/compute_map.pxd
@@ -12,17 +12,8 @@ cdef extern from "map/include/computeMap.hpp" namespace "skch" nogil:
 
     cdef cppclass Map:
 
-        struct L1_candidateLocus_t:
-            seqno_t  seqId
-            offset_t rangeStartPos
-            offset_t rangeEndPos
-
-        struct L2_mapLocus_t:
-            seqno_t         seqId
-            offset_t        meanOptimalPos
-            Sketch.MIIter_t optimalStart
-            Sketch.MIIter_t optimalEnd
-            int             sharedSketchSize
+        Parameters &param
+        Sketch &refSketch
 
         ctypedef Sketch.MI_Type MinVec_Type
         ctypedef Sketch.MIIter_t MIIter_t
@@ -37,9 +28,24 @@ cdef extern from "map/include/computeMap.hpp" namespace "skch" nogil:
             function[void(MappingResult&)] f = nullptr
         )
 
-        void mapSingleQuerySeq[Q](Q&, MappingResultsVector_t&, ofstream&)
+        void mapSingleQuerySeq[Q](Q&, MappingResultsVector_t&, ofstream&) except +
         void doL1Mapping[Q, V](Q&, V&)
-        void computeL1CandidateRegions[Q, V1, V2](Q&, V1&, int, V2&)
-        void doL2Mapping[Q, V1, V2](Q&, V1&, V2&)
-        void computeL2MappedRegions[Q](Q&, L1_candidateLocus_t&, L2_mapLocus_t&)
-        void reportL2Mappings(MappingResultsVector_t&, ofstream&)
+        void computeL1CandidateRegions[Q, V1, V2](Q&, V1&, int, V2&) except +
+        void doL2Mapping[Q, V1, V2](Q&, V1&, V2&) except +
+        void computeL2MappedRegions[Q](Q&, L1_candidateLocus_t&, L2_mapLocus_t&) except +
+        void reportL2Mappings(MappingResultsVector_t&, ofstream&) except +
+
+
+cdef extern from "map/include/computeMap.hpp" namespace "skch::Map" nogil:
+
+    cdef struct L1_candidateLocus_t:
+        seqno_t  seqId
+        offset_t rangeStartPos
+        offset_t rangeEndPos
+
+    cdef struct L2_mapLocus_t:
+        seqno_t         seqId
+        offset_t        meanOptimalPos
+        Sketch.MIIter_t optimalStart
+        Sketch.MIIter_t optimalEnd
+        int             sharedSketchSize
diff --git a/pyfastani/_fastani.pyx b/pyfastani/_fastani.pyx
index 5f35793..84b3482 100644
--- a/pyfastani/_fastani.pyx
+++ b/pyfastani/_fastani.pyx
@@ -13,10 +13,13 @@ References:
 # --- C imports --------------------------------------------------------------
 
 cimport cython
+cimport cython.parallel
 cimport libcpp11.chrono
 from cython.operator cimport dereference, preincrement, postincrement
 from cpython.ref cimport Py_INCREF
 from cpython.list cimport PyList_New, PyList_SET_ITEM
+from cpython.pycapsule cimport PyCapsule_New, PyCapsule_GetPointer, PyCapsule_CheckExact
+from libc.stdio cimport printf
 from libc.string cimport memcpy
 from libc.limits cimport INT_MAX
 from libc.stdint cimport int64_t, uint64_t
@@ -30,22 +33,28 @@ from libcpp.string cimport string
 from libcpp.unordered_map cimport unordered_map
 from libcpp.vector cimport vector
 from libcpp11.fstream cimport ofstream
+from openmp cimport omp_lock_t, omp_init_lock, omp_set_lock, omp_unset_lock
 
 cimport fastani.map.map_stats
 from kseq cimport kseq_t, kstring_t
 from fastani.cgi.compute_core_identity cimport computeCGI
 from fastani.cgi.cgid_types cimport CGI_Results
 from fastani.map cimport base_types
-from fastani.map.compute_map cimport Map as Map_t
+from fastani.map.compute_map cimport (
+    Map as Map_t,
+    L1_candidateLocus_t,
+    L2_mapLocus_t
+)
 from fastani.map.common_func cimport addMinimizers, getHash
 from fastani.map.map_parameters cimport Parameters as Parameters_t
-from fastani.map.map_stats cimport estimateMinimumHitsRelaxed, recommendedWindowSize
+from fastani.map.map_stats cimport estimateMinimumHitsRelaxed, recommendedWindowSize, j2md, md_lower_bound
 from fastani.map.win_sketch cimport Sketch as Sketch_t
 from fastani.map.base_types cimport (
     hash_t,
     seqno_t,
     offset_t,
     ContigInfo as ContigInfo_t,
+    MappingResult as MappingResult_t,
     MappingResultsVector_t,
     MinimizerInfo as MinimizerInfo_t,
     MinimizerMetaData as MinimizerMetaData_t,
@@ -60,12 +69,14 @@ from fastani.map.base_types cimport (
 from _utils cimport kseq_ptr_t, toupper, distance
 from _unicode cimport *
 from _sequtils cimport copy_upper, reverse_complement
-
+from _atomic_vector cimport atomic_vector
 
 # --- Python imports ---------------------------------------------------------
 
 import array
 import warnings
+import threading
+import multiprocessing.pool
 
 # --- Constants --------------------------------------------------------------
 
@@ -274,6 +285,14 @@ cdef int _add_minimizers_prot(
 # --- Cython classes ---------------------------------------------------------
 
 
+cdef class _Map:
+    cdef Map_t* _map
+
+
+cdef class _FinalMappings:
+    cdef atomic_vector[MappingResult_t] _vec
+
+
 cdef class _Parameterized:
     """A base class for types wrapping a `skch::Parameters` C++ object.
     """
@@ -809,7 +828,7 @@ cdef class Mapper(_Parameterized):
         const void* data,
         const ssize_t slen,
         QueryMetaData_t[kseq_ptr_t, vector[MinimizerInfo_t]]& query,
-        vector[Map_t.L1_candidateLocus_t]& l1_mappings,
+        vector[L1_candidateLocus_t]& l1_mappings,
     ) nogil:
         """Compute L1 mappings for the given sequence block.
 
@@ -871,43 +890,43 @@ cdef class Mapper(_Parameterized):
         minimum_hits = estimateMinimumHitsRelaxed(query.sketchSize, param.kmerSize, param.percentageIdentity)
         map.computeL1CandidateRegions(query, seed_hits_l1, minimum_hits, l1_mappings)
 
-    @staticmethod
-    cdef void _query_fragment(
-        const Parameters_t& param,
-        const Sketch_t& sketch,
-        Map_t& map,
+    cpdef void _query_fragment(
+        self,
+        _Map map,
         const int i,
         const seqno_t seq_counter,
-        const int kind,
-        const void* data,
-        const ssize_t slen,
-        const size_t stride,
-        vector[Map_t.L1_candidateLocus_t]& l1_mappings,
-        MappingResultsVector_t& l2_mappings
-    ) nogil:
+        const unsigned char[::1] sequence,
+        _FinalMappings final_mappings,
+    ) except *:
         cdef kseq_t                                         kseq
         cdef QueryMetaData_t[kseq_ptr_t, Map_t.MinVec_Type] query
+        cdef vector[L1_candidateLocus_t]                    l1_mappings
+        cdef const unsigned char*                           fragment
 
-        query.kseq = &kseq
-        query.kseq.seq.s = NULL
-        query.kseq.seq.l = param.minReadLength
-        query.seqCounter = seq_counter + i
-
-        Mapper._do_l1_mappings(
-            # classes with configuration
-            param,
-            sketch,
-            map,
-            # start of the i-th fragment
-            kind,
-            <const void*> ((<const char*> data) + (i * param.minReadLength) * stride),
-            param.minReadLength,
-            # outputs
-            query,
-            l1_mappings,
-        )
+        with nogil:
+
+            fragment = &sequence[i*self._param.minReadLength]
+
+            query.kseq = &kseq
+            query.kseq.seq.s = NULL
+            query.kseq.seq.l = self._param.minReadLength
+            query.seqCounter = seq_counter + i
 
-        map.doL2Mapping(query, l1_mappings, l2_mappings)
+            Mapper._do_l1_mappings(
+                # classes with configuration
+                self._param,
+                self._sk[0],
+                map._map[0],
+                # start of the i-th fragment
+                PyUnicode_1BYTE_KIND,
+                fragment,
+                self._param.minReadLength,
+                # outputs
+                query,
+                l1_mappings,
+            )
+
+            map._map.doL2Mapping(query, l1_mappings, final_mappings._vec)
 
     cdef list _query_draft(self, object contigs):
         """Query the sketcher for the given contigs.
@@ -925,18 +944,11 @@ cdef class Mapper(_Parameterized):
         cdef uint64_t                          total_fragments = 0
         cdef uint64_t                          total_length    = 0
         # mapping parameters and reference
-        cdef Map_t*                            map
-        cdef Parameters_t*                     param           = &self._param
-        cdef Sketch_t*                         sketch          = self._sk
+        cdef _Map                              map
+        cdef _FinalMappings                    final_mappings
         # sequence as a unicode object
-        cdef const unsigned char[::1]          view
-        cdef int                               kind
-        cdef void*                             data
         cdef ssize_t                           slen
-        cdef size_t                            stride
         # core genomic identity results
-        cdef vector[Map_t.L1_candidateLocus_t] l1_mappings
-        cdef MappingResultsVector_t            final_mappings
         cdef vector[CGI_Results]               results
         cdef CGI_Results                       result
         # filtering and reporing results
@@ -944,98 +956,110 @@ cdef class Mapper(_Parameterized):
         cdef uint64_t                          shared_length
         cdef list                              hits            = []
 
+        cdef object                            l2_lock = threading.Lock()
+
         # create a new mapper with the given mapping result vector
-        map = new Map_t(param[0], sketch[0], total_fragments, 0)
+        map = _Map.__new__(_Map)
+        map._map = new Map_t(self._param, self._sk[0], total_fragments, 0)
+
+        # create the result vector
+        final_mappings = _FinalMappings.__new__(_FinalMappings)
+
+        #
+        with multiprocessing.pool.ThreadPool() as pool:
+
+            # iterate over contigs
+            for contig in contigs:
+
+                # check length of contig is enough for computing mapping
+                slen = len(contig)
+                if slen < min(self._param.windowSize, self._param.kmerSize, self._param.minReadLength):
+                    warnings.warn(
+                        (
+                            "Mapper received a short sequence relative to parameters, "
+                            "mapping will not be computed."
+                        ),
+                        UserWarning,
+                    )
+
+                # encode
+                if isinstance(contig, str):
+                    contig = contig.encode('ascii')
+
+                # # get a way to read each letter of the contig,
+                # # independently of it being `str`, `bytes`, `bytearray`, etc.
+                # if isinstance(contig, str):
+                #     # make sure the unicode string is in canonical form,
+                #     # --> won't be needed anymore in Python 3.12
+                #     IF SYS_VERSION_INFO_MAJOR <= 3 and SYS_VERSION_INFO_MINOR < 12:
+                #         PyUnicode_READY(contig)
+                #     # get kind and data for efficient indexing
+                #     kind = PyUnicode_KIND(contig)
+                #     data = PyUnicode_DATA(contig)
+                #     slen = PyUnicode_GET_LENGTH(contig)
+                #     if kind == PyUnicode_1BYTE_KIND:
+                #         stride = sizeof(Py_UCS1)
+                #     elif kind == PyUnicode_2BYTE_KIND:
+                #         stride = sizeof(Py_UCS2)
+                #     else:
+                #         stride = sizeof(Py_UCS4)
+                # else:
+                #     # attempt to view the contig as a buffer of contiguous bytes
+                #     view = contig
+                #     # pretend the bytes are an ASCII (UCS-1) encoded string
+                #     kind = PyUnicode_1BYTE_KIND
+                #     slen = view.shape[0]
+                #     stride = sizeof(Py_UCS1)
+                #     if slen != 0:
+                #         data = <void*> &view[0]
+
+                # compute the expected number of blocks
+                fragment_count = slen // self._param.minReadLength
+
+                # map the blocks
+                # for i in range(fragment_count):
+                #     self._query_fragment(
+                #         map,
+                #         # position in query
+                #         i,
+                #         total_fragments,
+                #         # sequence
+                #         contig,
+                #         # result vector
+                #         final_mappings
+                #     )
+                pool.map(
+                    lambda i: self._query_fragment(map, i, total_fragments, contig, final_mappings),
+                    range(fragment_count),
+                )
 
-        # iterate over contigs
-        for contig in contigs:
+                # record the number of fragments
+                total_fragments += fragment_count
+                total_length += slen
 
-            # get a way to read each letter of the contig,
-            # independently of it being `str`, `bytes`, `bytearray`, etc.
-            if isinstance(contig, str):
-                # make sure the unicode string is in canonical form,
-                # --> won't be needed anymore in Python 3.12
-                IF SYS_VERSION_INFO_MAJOR <= 3 and SYS_VERSION_INFO_MINOR < 12:
-                    PyUnicode_READY(contig)
-                # get kind and data for efficient indexing
-                kind = PyUnicode_KIND(contig)
-                data = PyUnicode_DATA(contig)
-                slen = PyUnicode_GET_LENGTH(contig)
-                if kind == PyUnicode_1BYTE_KIND:
-                    stride = sizeof(Py_UCS1)
-                elif kind == PyUnicode_2BYTE_KIND:
-                    stride = sizeof(Py_UCS2)
-                else:
-                    stride = sizeof(Py_UCS4)
-            else:
-                # attempt to view the contig as a buffer of contiguous bytes
-                view = contig
-                # pretend the bytes are an ASCII (UCS-1) encoded string
-                kind = PyUnicode_1BYTE_KIND
-                slen = view.shape[0]
-                stride = sizeof(Py_UCS1)
-                if slen != 0:
-                    data = <void*> &view[0]
-
-            # query if the sequence is large enough
-            if slen >= param.windowSize and slen >= param.kmerSize and slen >= param.minReadLength:
-                with nogil:
-                    # compute the expected number of blocks
-                    fragment_count = slen // param.minReadLength
-                    # map the blocks
-                    for i in range(fragment_count):
-                        l1_mappings.clear()
-                        Mapper._query_fragment(
-                            # classes with configuration
-                            param[0],
-                            sketch[0],
-                            map[0],
-                            # position in query
-                            i,
-                            total_fragments,
-                            # unicode sequence
-                            kind,
-                            data,
-                            slen,
-                            stride,
-                            # temporary vector
-                            l1_mappings,
-                            # result vector
-                            final_mappings
-                        )
-                    # record the number of fragments
-                    total_fragments += fragment_count
-                    total_length += slen
-            else:
-                warnings.warn(
-                    (
-                        "Mapper received a short sequence relative to parameters, "
-                        "mapping will not be computed."
-                    ),
-                    UserWarning,
-                )
+        #
+        # printf("FINALL MAPPINGS: %i\n", final_mappings._vec.size())
 
         # compute core genomic identity after successful mapping
         with nogil:
             computeCGI(
-                param[0],
-                final_mappings,
-                map[0],
+                self._param,
+                final_mappings._vec,
+                map._map[0],
                 self._sk[0],
                 total_fragments, # total query fragments
                 0, # queryFileNo, only used for visualization, ignored
                 string(), # fileName, only used for reporting, ignored
                 results,
             )
-        # free the map
-        del map
+
         # build and return the list of hits
         for result in results:
             assert result.refGenomeId < self._lengths.size()
             assert result.refGenomeId < len(self._names)
             min_length = min(total_length, self._lengths[result.refGenomeId])
-            shared_length = result.countSeq * param.minReadLength
-            if shared_length >= min_length * param.minFraction:
+            shared_length = result.countSeq * self._param.minReadLength
+            if shared_length >= min_length * self._param.minFraction:
                 hits.append(Hit(
                     name=self._names[result.refGenomeId],
                     identity=result.identity,
-- 
GitLab