Commit a3f77269 authored by Martin Larralde
Browse files

Update `pyhmmer.hmmer` to use a `deque` with semaphores to pass query data

parent f6abd16c
......@@ -69,8 +69,8 @@ class _PipelineThread(typing.Generic[_Q], threading.Thread):
Attributes:
sequence (`pyhmmer.plan7.PipelineSearchTargets`): The target
sequences to search for hits.
query_queue (`queue.Queue`): The queue used to pass queries between
threads. It contains both the query, its index so that the
query_queue (`collections.deque`): The queue used to pass queries
between threads. It contains the query, its index so that the
results can be returned in the same order, and a `_ResultBuffer`
where to store the result when the query has been processed.
query_count (`multiprocessing.Value`): An atomic counter storing
......@@ -98,7 +98,8 @@ class _PipelineThread(typing.Generic[_Q], threading.Thread):
def __init__(
self,
sequences: PipelineSearchTargets,
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]]",
query_available: threading.Semaphore,
query_queue: typing.Deque[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]],
query_count: multiprocessing.Value, # type: ignore
kill_switch: threading.Event,
callback: typing.Optional[typing.Callable[[_Q, int], None]],
......@@ -110,7 +111,8 @@ class _PipelineThread(typing.Generic[_Q], threading.Thread):
self.options = options
self.sequences = sequences
self.pipeline = pipeline_class(alphabet=alphabet, **options)
self.query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]]" = query_queue
self.query_available: threading.Semaphore = query_available
self.query_queue: typing.Deque[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]] = query_queue
self.query_count = query_count
self.callback: typing.Optional[typing.Callable[[_Q, int], None]] = callback or self._none_callback
self.kill_switch = kill_switch
......@@ -121,14 +123,12 @@ class _PipelineThread(typing.Generic[_Q], threading.Thread):
# attempt to get the next argument, with a timeout
# so that the thread can periodically check if it has
# been killed, even when the query queue is empty
try:
args = self.query_queue.get(timeout=1)
except queue.Empty:
if not self.query_available.acquire(timeout=1):
continue
args = self.query_queue.popleft()
# check if arguments from the queue are a poison-pill (`None`),
# in which case the thread will stop running
if args is None:
self.query_queue.task_done()
return
else:
index, query, result_buffer = args
......@@ -137,7 +137,6 @@ class _PipelineThread(typing.Generic[_Q], threading.Thread):
# as "found" using a `threading.Event` for each.
try:
hits = self.process(index, query)
self.query_queue.task_done()
result_buffer.set(hits)
except BaseException as exc:
self.error = exc
......@@ -167,7 +166,8 @@ class _SequencePipelineThread(_PipelineThread[DigitalSequence]):
def __init__(
self,
sequences: PipelineSearchTargets,
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalSequence, _ResultBuffer]]]",
query_available: threading.Semaphore,
query_queue: typing.Deque[typing.Optional[typing.Tuple[int, DigitalSequence, _ResultBuffer]]],
query_count: multiprocessing.Value, # type: ignore
kill_switch: threading.Event,
callback: typing.Optional[typing.Callable[[DigitalSequence, int], None]],
......@@ -178,6 +178,7 @@ class _SequencePipelineThread(_PipelineThread[DigitalSequence]):
) -> None:
super().__init__(
sequences,
query_available,
query_queue,
query_count,
kill_switch,
......@@ -196,7 +197,8 @@ class _MSAPipelineThread(_PipelineThread[DigitalMSA]):
def __init__(
self,
sequences: PipelineSearchTargets,
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalMSA, _ResultBuffer]]]",
query_available: threading.Semaphore,
query_queue: typing.Deque[typing.Optional[typing.Tuple[int, DigitalMSA, _ResultBuffer]]],
query_count: multiprocessing.Value, # type: ignore
kill_switch: threading.Event,
callback: typing.Optional[typing.Callable[[DigitalMSA, int], None]],
......@@ -207,6 +209,7 @@ class _MSAPipelineThread(_PipelineThread[DigitalMSA]):
) -> None:
super().__init__(
sequences,
query_available,
query_queue,
query_count,
kill_switch,
......@@ -249,7 +252,8 @@ class _Search(typing.Generic[_Q], abc.ABC):
@abc.abstractmethod
def _new_thread(
self,
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]]",
query_available: threading.Semaphore,
query_queue: typing.Deque[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]],
query_count: "multiprocessing.Value[int]", # type: ignore
kill_switch: threading.Event,
) -> _PipelineThread[_Q]:
......@@ -258,12 +262,13 @@ class _Search(typing.Generic[_Q], abc.ABC):
def _single_threaded(self) -> typing.Iterator[TopHits]:
# create the queues to pass the HMM objects around, as well as atomic
# values that we use to synchronize the threads
query_queue = queue.Queue() # type: ignore
query_available = threading.Semaphore(0)
query_queue = collections.deque() # type: ignore
query_count = multiprocessing.Value(ctypes.c_ulong)
kill_switch = threading.Event()
# create the thread (to recycle code)
thread = self._new_thread(query_queue, query_count, kill_switch)
thread = self._new_thread(query_available, query_queue, query_count, kill_switch)
# process each HMM iteratively and yield the result
# immediately so that the user can iterate over the
......@@ -273,6 +278,7 @@ class _Search(typing.Generic[_Q], abc.ABC):
yield thread.process(index, query)
def _multi_threaded(self) -> typing.Iterator[TopHits]:
query_available = threading.Semaphore(0)
# create the queues to pass the HMM objects around, as well as atomic
# values that we use to synchronize the threads
results: typing.Deque[_ResultBuffer] = collections.deque()
......@@ -280,7 +286,7 @@ class _Search(typing.Generic[_Q], abc.ABC):
kill_switch = threading.Event()
# the query queue is bounded so that we only feed more queries
# if the worker threads are waiting for some
query_queue = queue.Queue(maxsize=self.cpus) # type: ignore
query_queue = collections.deque() # type: ignore
# additional type annotations
query: typing.Optional[_Q]
index: int
......@@ -288,7 +294,7 @@ class _Search(typing.Generic[_Q], abc.ABC):
# create and launch one pipeline thread per CPU
threads = []
for _ in range(self.cpus):
thread = self._new_thread(query_queue, query_count, kill_switch)
thread = self._new_thread(query_available, query_queue, query_count, kill_switch)
thread.start()
threads.append(thread)
......@@ -302,7 +308,8 @@ class _Search(typing.Generic[_Q], abc.ABC):
for (index, query) in itertools.islice(queries, self.cpus):
query_count.value += 1
query_result = _ResultBuffer()
query_queue.put((index, query, query_result))
query_queue.append((index, query, query_result))
query_available.release()
results.append(query_result)
# alternate between feeding queries to the threads and
# yielding back results, if available
......@@ -315,7 +322,8 @@ class _Search(typing.Generic[_Q], abc.ABC):
else:
query_count.value += 1
query_result = _ResultBuffer()
query_queue.put((index, query))
query_queue.append((index, query, query_result))
query_available.release()
results.append(query_result)
# yield the top hits for the next query, if available
if results[0].available():
......@@ -324,7 +332,8 @@ class _Search(typing.Generic[_Q], abc.ABC):
# now that we exhausted all queries, poison pill the
# threads so they stop on their own
for _ in threads:
query_queue.put(None)
query_queue.append(None)
query_available.release()
# yield remaining results
while results:
yield results[0].get()
......@@ -346,12 +355,14 @@ class _ModelSearch(typing.Generic[_M], _Search[_M]):
def _new_thread(
self,
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _M, _ResultBuffer]]]",
query_available: threading.Semaphore,
query_queue: typing.Deque[typing.Optional[typing.Tuple[int, _M, _ResultBuffer]]],
query_count: "multiprocessing.Value[int]", # type: ignore
kill_switch: threading.Event,
) -> _ModelPipelineThread[_M]:
return _ModelPipelineThread(
self.sequences,
query_available,
query_queue,
query_count,
kill_switch,
......@@ -380,12 +391,14 @@ class _SequenceSearch(_Search[DigitalSequence]):
def _new_thread(
self,
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalSequence, _ResultBuffer]]]",
query_available: threading.Semaphore,
query_queue: typing.Deque[typing.Optional[typing.Tuple[int, DigitalSequence, _ResultBuffer]]],
query_count: "multiprocessing.Value[int]", # type: ignore
kill_switch: threading.Event,
) -> _SequencePipelineThread:
return _SequencePipelineThread(
self.sequences,
query_available,
query_queue,
query_count,
kill_switch,
......@@ -415,12 +428,14 @@ class _MSASearch(_Search[DigitalMSA]):
def _new_thread(
self,
query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalMSA, _ResultBuffer]]]",
query_available: threading.Semaphore,
query_queue: typing.Deque[typing.Optional[typing.Tuple[int, DigitalMSA, _ResultBuffer]]],
query_count: "multiprocessing.Value[int]", # type: ignore
kill_switch: threading.Event,
) -> _MSAPipelineThread:
return _MSAPipelineThread(
self.sequences,
query_available,
query_queue,
query_count,
kill_switch,
......
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment