Commit f6abd16c authored by Martin Larralde's avatar Martin Larralde
Browse files

Rewrite `pyhmmer.hmmer` threading code with a `Deque` to record query order

parent c3d6ce03
......@@ -28,6 +28,39 @@ _M = typing.TypeVar("_M", HMM, Profile, OptimizedProfile)
# the sequence type for the pipeline
_S = typing.TypeVar("_S", DigitalSequence, DigitalMSA)
# --- Result class -----------------------------------------------------------
class _ResultBuffer:
    """A single-use buffer storing the outcome of one pipeline query.

    A worker thread records either a `TopHits` result (`set`) or an
    exception (`fail`); the main thread blocks in `get` until one of
    the two has been recorded, and then receives the hits or has the
    exception re-raised in its own context.
    """

    __slots__ = ("event", "hits", "exception")

    # slot attribute types
    event: threading.Event
    hits: typing.Optional[TopHits]
    exception: typing.Optional[BaseException]

    def __init__(self) -> None:
        self.hits = None
        self.exception = None
        self.event = threading.Event()

    def available(self) -> bool:
        """Check whether an outcome (result or error) has been recorded."""
        return self.event.is_set()

    def set(self, hits: TopHits) -> None:
        """Record ``hits`` as the result and wake up any waiting thread."""
        self.hits = hits
        self.event.set()

    def fail(self, exception: BaseException) -> None:
        """Record ``exception`` as the outcome and wake up any waiting thread."""
        self.exception = exception
        self.event.set()

    def get(self) -> TopHits:
        """Block until an outcome is recorded, then return hits or re-raise."""
        self.event.wait()
        if self.exception is None:
            # `hits` was assigned by `set`; cast away the Optional for typing.
            return typing.cast(TopHits, self.hits)
        raise self.exception
# --- Pipeline threads -------------------------------------------------------
class _PipelineThread(typing.Generic[_Q], threading.Thread):
......@@ -37,25 +70,15 @@ class _PipelineThread(typing.Generic[_Q], threading.Thread):
sequence (`pyhmmer.plan7.PipelineSearchTargets`): The target
sequences to search for hits.
query_queue (`queue.Queue`): The queue used to pass queries between
threads. It contains both the query and its index, so that the
results can be returned in the same order.
threads. It contains the query, its index (so that the
results can be returned in the same order), and a `_ResultBuffer`
in which to store the result once the query has been processed.
query_count (`multiprocessing.Value`): An atomic counter storing
the total number of queries that have currently been loaded.
Passed to the ``callback`` so that an UI can show the total
for a progress bar.
hits_queue (`queue.PriorityQueue`): The queue used to pass back
the `TopHits` to the main thread. The results are inserted
using the index of the query, so that the main thread can
pull results in order.
kill_switch (`threading.Event`): An event flag shared between
all worker threads, used to notify emergency exit.
hits_found (`list` of `threading.Event`): A list of event flags,
such that ``hits_found[i]`` is set when results have been
obtained for the query of index ``i``. This allows the main
thread to keep waiting for the right `TopHits` to yield even
if subsequent queries have already been treated, and to make
sure the next result returned by ``hits_queue.get`` will also
be of index ``i``.
callback (`callable`, optional): An optional callback to be called
after each query has been processed. It should accept two
arguments: the query object that was processed, and the total
......@@ -75,11 +98,9 @@ class _PipelineThread(typing.Generic[_Q], threading.Thread):
def __init__(
self,
sequences: PipelineSearchTargets,
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q]]]",
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]]",
query_count: multiprocessing.Value, # type: ignore
hits_queue: "queue.PriorityQueue[typing.Tuple[int, TopHits]]",
kill_switch: threading.Event,
hits_found: typing.List[threading.Event],
callback: typing.Optional[typing.Callable[[_Q, int], None]],
options: typing.Dict[str, typing.Any],
pipeline_class: typing.Type[Pipeline],
......@@ -89,12 +110,10 @@ class _PipelineThread(typing.Generic[_Q], threading.Thread):
self.options = options
self.sequences = sequences
self.pipeline = pipeline_class(alphabet=alphabet, **options)
self.query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q]]]" = query_queue
self.query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]]" = query_queue
self.query_count = query_count
self.hits_queue = hits_queue
self.callback: "typing.Optional[typing.Callable[[_Q, int], None]]" = callback or self._none_callback
self.callback: typing.Optional[typing.Callable[[_Q, int], None]] = callback or self._none_callback
self.kill_switch = kill_switch
self.hits_found = hits_found
self.error: typing.Optional[BaseException] = None
def run(self) -> None:
......@@ -112,28 +131,27 @@ class _PipelineThread(typing.Generic[_Q], threading.Thread):
self.query_queue.task_done()
return
else:
index, query = args
index, query, result_buffer = args
# process the arguments, making sure to capture any exception
# raised while processing the query, and then mark the hits
# as "found" using a `threading.Event` for each.
try:
self.process(index, query)
hits = self.process(index, query)
self.query_queue.task_done()
result_buffer.set(hits)
except BaseException as exc:
self.error = exc
self.kill()
return
finally:
self.hits_found[index].set()
result_buffer.fail(exc)
def kill(self) -> None:
    """Set the kill switch shared by all worker threads to request an emergency exit."""
    self.kill_switch.set()
def process(self, index: int, query: _Q) -> None:
def process(self, index: int, query: _Q) -> TopHits:
hits = self.search(query)
self.hits_queue.put((index, hits))
self.callback(query, self.query_count.value) # type: ignore
self.pipeline.clear()
return hits
@abc.abstractmethod
def search(self, query: _Q) -> TopHits:
......@@ -149,11 +167,9 @@ class _SequencePipelineThread(_PipelineThread[DigitalSequence]):
def __init__(
self,
sequences: PipelineSearchTargets,
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalSequence]]]",
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalSequence, _ResultBuffer]]]",
query_count: multiprocessing.Value, # type: ignore
hits_queue: "queue.PriorityQueue[typing.Tuple[int, TopHits]]",
kill_switch: threading.Event,
hits_found: typing.List[threading.Event],
callback: typing.Optional[typing.Callable[[DigitalSequence, int], None]],
options: typing.Dict[str, typing.Any],
pipeline_class: typing.Type[Pipeline],
......@@ -164,9 +180,7 @@ class _SequencePipelineThread(_PipelineThread[DigitalSequence]):
sequences,
query_queue,
query_count,
hits_queue,
kill_switch,
hits_found,
callback,
options,
pipeline_class,
......@@ -182,11 +196,9 @@ class _MSAPipelineThread(_PipelineThread[DigitalMSA]):
def __init__(
self,
sequences: PipelineSearchTargets,
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalMSA]]]",
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalMSA, _ResultBuffer]]]",
query_count: multiprocessing.Value, # type: ignore
hits_queue: "queue.PriorityQueue[typing.Tuple[int, TopHits]]",
kill_switch: threading.Event,
hits_found: typing.List[threading.Event],
callback: typing.Optional[typing.Callable[[DigitalMSA, int], None]],
options: typing.Dict[str, typing.Any],
pipeline_class: typing.Type[Pipeline],
......@@ -197,9 +209,7 @@ class _MSAPipelineThread(_PipelineThread[DigitalMSA]):
sequences,
query_queue,
query_count,
hits_queue,
kill_switch,
hits_found,
callback,
options,
pipeline_class,
......@@ -239,39 +249,33 @@ class _Search(typing.Generic[_Q], abc.ABC):
@abc.abstractmethod
def _new_thread(
self,
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q]]]",
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]]",
query_count: "multiprocessing.Value[int]", # type: ignore
hits_queue: "queue.PriorityQueue[typing.Tuple[int, TopHits]]",
kill_switch: threading.Event,
hits_found: typing.List[threading.Event],
) -> _PipelineThread[_Q]:
return NotImplemented
def _single_threaded(self) -> typing.Iterator[TopHits]:
# create the queues to pass the HMM objects around, as well as atomic
# values that we use to synchronize the threads
hits_found: typing.List[threading.Event] = []
query_queue = queue.Queue() # type: ignore
query_count = multiprocessing.Value(ctypes.c_ulong)
hits_queue = queue.PriorityQueue() # type: ignore
kill_switch = threading.Event()
# create the thread (to recycle code)
thread = self._new_thread(query_queue, query_count, hits_queue, kill_switch, hits_found)
thread = self._new_thread(query_queue, query_count, kill_switch)
# process each HMM iteratively and yield the result
# immediately so that the user can iterate over the
# TopHits one at a time
for index, query in enumerate(self.queries):
query_count.value += 1
thread.process(index, query)
yield hits_queue.get_nowait()[1]
yield thread.process(index, query)
def _multi_threaded(self) -> typing.Iterator[TopHits]:
# create the queues to pass the HMM objects around, as well as atomic
# values that we use to synchronize the threads
hits_found: typing.List[threading.Event] = []
hits_queue = queue.PriorityQueue() # type: ignore
results: typing.Deque[_ResultBuffer] = collections.deque()
query_count = multiprocessing.Value(ctypes.c_ulong)
kill_switch = threading.Event()
# the query queue is bounded so that we only feed more queries
......@@ -284,7 +288,7 @@ class _Search(typing.Generic[_Q], abc.ABC):
# create and launch one pipeline thread per CPU
threads = []
for _ in range(self.cpus):
thread = self._new_thread(query_queue, query_count, hits_queue, kill_switch, hits_found)
thread = self._new_thread(query_queue, query_count, kill_switch)
thread.start()
threads.append(thread)
......@@ -297,12 +301,12 @@ class _Search(typing.Generic[_Q], abc.ABC):
# working before we enter the main loop
for (index, query) in itertools.islice(queries, self.cpus):
query_count.value += 1
hits_found.append(threading.Event())
query_queue.put((index, query))
query_result = _ResultBuffer()
query_queue.put((index, query, query_result))
results.append(query_result)
# alternate between feeding queries to the threads and
# yielding back results, if available
hits_yielded = 0
while hits_yielded < query_count.value:
while results:
# get the next query, or break the loop if there is no query
# left to process in the input iterator.
index, query = next(queries, (-1, None))
......@@ -310,31 +314,21 @@ class _Search(typing.Generic[_Q], abc.ABC):
break
else:
query_count.value += 1
hits_found.append(threading.Event())
query_result = _ResultBuffer()
query_queue.put((index, query))
results.append(query_result)
# yield the top hits for the next query, if available
if hits_found[hits_yielded].is_set():
yield hits_queue.get_nowait()[1]
hits_yielded += 1
if results[0].available():
yield results[0].get()
results.popleft()
# now that we exhausted all queries, poison pill the
# threads so they stop on their own
for _ in threads:
query_queue.put(None)
# yield remaining results
while hits_yielded < query_count.value:
hits_found[hits_yielded].wait()
yield hits_queue.get_nowait()[1]
hits_yielded += 1
except queue.Empty:
# the only way we can get queue.Empty is if a thread has set
# the flag for `hits_found[i]` without actually putting it in
# the queue: this only happens when a background thread raised
# an exception, so we must chain it
for thread in threads:
if thread.error is not None:
raise thread.error from None
# if this exception is otherwise a bug, make sure to reraise it
raise
while results:
yield results[0].get()
results.popleft()
except BaseException:
# make sure threads are killed to avoid being stuck,
# e.g. after a KeyboardInterrupt
......@@ -352,19 +346,15 @@ class _ModelSearch(typing.Generic[_M], _Search[_M]):
def _new_thread(
self,
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _M]]]",
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _M, _ResultBuffer]]]",
query_count: "multiprocessing.Value[int]", # type: ignore
hits_queue: "queue.PriorityQueue[typing.Tuple[int, TopHits]]",
kill_switch: threading.Event,
hits_found: typing.List[threading.Event],
) -> _ModelPipelineThread[_M]:
return _ModelPipelineThread(
self.sequences,
query_queue,
query_count,
hits_queue,
kill_switch,
hits_found,
self.callback,
self.options,
self.pipeline_class,
......@@ -390,19 +380,15 @@ class _SequenceSearch(_Search[DigitalSequence]):
def _new_thread(
self,
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalSequence]]]",
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalSequence, _ResultBuffer]]]",
query_count: "multiprocessing.Value[int]", # type: ignore
hits_queue: "queue.PriorityQueue[typing.Tuple[int, TopHits]]",
kill_switch: threading.Event,
hits_found: typing.List[threading.Event],
) -> _SequencePipelineThread:
return _SequencePipelineThread(
self.sequences,
query_queue,
query_count,
hits_queue,
kill_switch,
hits_found,
self.callback,
self.options,
self.pipeline_class,
......@@ -429,19 +415,15 @@ class _MSASearch(_Search[DigitalMSA]):
def _new_thread(
self,
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalMSA]]]",
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalMSA, _ResultBuffer]]]",
query_count: "multiprocessing.Value[int]", # type: ignore
hits_queue: "queue.PriorityQueue[typing.Tuple[int, TopHits]]",
kill_switch: threading.Event,
hits_found: typing.List[threading.Event],
) -> _MSAPipelineThread:
return _MSAPipelineThread(
self.sequences,
query_queue,
query_count,
hits_queue,
kill_switch,
hits_found,
self.callback,
self.options,
self.pipeline_class,
......@@ -493,7 +475,7 @@ def hmmsearch(
"""
# count the number of CPUs to use
_cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count()
_cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1
runner: _ModelSearch[_M] = _ModelSearch(queries, sequences, _cpus, callback, **options) # type: ignore
return runner.run()
......@@ -541,7 +523,7 @@ def phmmer(
Allow using `DigitalMSA` queries.
"""
_cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count()
_cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1
_builder = Builder(Alphabet.amino()) if builder is None else builder
try:
......@@ -626,7 +608,7 @@ def nhmmer(
Allow using `Profile` and `OptimizedProfile` queries.
"""
_cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count()
_cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1
_builder = Builder(Alphabet.dna()) if builder is None else builder
try:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment