Commit e43dfa18 authored by Martin Larralde
Browse files

Reduce memory consumption of query queues in `pyhmmer.hmmer`

parent a3f77269
......@@ -30,33 +30,63 @@ _S = typing.TypeVar("_S", DigitalSequence, DigitalMSA)
# --- Result class -----------------------------------------------------------
class _ResultBuffer:
class _Chore(typing.Generic[_Q]):
"""A chore for a worker thread.
Attributes:
query (`object`): The query object to be processed by the worker
thread. Exact type depends on the pipeline type.
event (`threading.Event`): An event flag to set when the query
is done being processed.
hits (`pyhmmer.plan7.TopHits`): The hits obtained after processing
the query.
exception (`BaseException`): An exception that occurred while
processing the query.
"""
query: _Q
event: threading.Event
hits: typing.Optional[TopHits]
exception: typing.Optional[BaseException]
__slots__ = ("event", "hits", "exception")
__slots__ = ("query", "event", "hits", "exception")
def __init__(self) -> None:
def __init__(self, query: _Q) -> None:
"""Create a new chore from the given query.
"""
self.query = query
self.event = threading.Event()
self.hits = None
self.exception = None
def available(self) -> bool:
"""Return whether the chore is done and results are available.
"""
return self.event.is_set()
def wait(self, timeout=None) -> bool:
"""Wait for the chore to be done.
"""
return self.event.wait(timeout)
def get(self) -> TopHits:
"""Get the results of the chore, blocking if the chore was not done.
"""
self.event.wait()
if self.exception is not None:
raise self.exception
return typing.cast(TopHits, self.hits)
def set(self, hits: TopHits) -> None:
def complete(self, hits: TopHits) -> None:
"""Mark the chore as done and record ``hits`` as the results.
"""
self.hits = hits
self.event.set()
def fail(self, exception: BaseException) -> None:
"""Mark the chore as done and record ``exception`` as the error.
"""
self.exception = exception
self.event.set()
......@@ -99,7 +129,7 @@ class _PipelineThread(typing.Generic[_Q], threading.Thread):
self,
sequences: PipelineSearchTargets,
query_available: threading.Semaphore,
query_queue: typing.Deque[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]],
query_queue: typing.Deque[typing.Optional[_Chore[_Q]]],
query_count: multiprocessing.Value, # type: ignore
kill_switch: threading.Event,
callback: typing.Optional[typing.Callable[[_Q, int], None]],
......@@ -112,41 +142,36 @@ class _PipelineThread(typing.Generic[_Q], threading.Thread):
self.sequences = sequences
self.pipeline = pipeline_class(alphabet=alphabet, **options)
self.query_available: threading.Semaphore = query_available
self.query_queue: typing.Deque[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]] = query_queue
self.query_queue: typing.Deque[typing.Optional[_Chore[_Q]]] = query_queue
self.query_count = query_count
self.callback: typing.Optional[typing.Callable[[_Q, int], None]] = callback or self._none_callback
self.kill_switch = kill_switch
self.error: typing.Optional[BaseException] = None
def run(self) -> None:
while not self.kill_switch.is_set():
# attempt to get the next argument, with a timeout
# so that the thread can periodically check if it has
# been killed, even when the query queue is empty
# been killed, even when no queries are available
if not self.query_available.acquire(timeout=1):
continue
args = self.query_queue.popleft()
chore = self.query_queue.popleft()
# check if arguments from the queue are a poison-pill (`None`),
# in which case the thread will stop running
if args is None:
if chore is None:
return
else:
index, query, result_buffer = args
# process the arguments, making sure to capture any exception
# raised while processing the query, and then mark the hits
# as "found" using a `threading.Event` for each.
# process the query, making sure to capture any exception
# and then mark the hits as "found" using a `threading.Event`
try:
hits = self.process(index, query)
result_buffer.set(hits)
hits = self.process(chore.query)
chore.complete(hits)
except BaseException as exc:
self.error = exc
self.kill()
result_buffer.fail(exc)
chore.fail(exc)
def kill(self) -> None:
self.kill_switch.set()
def process(self, index: int, query: _Q) -> TopHits:
def process(self, query: _Q) -> TopHits:
hits = self.search(query)
self.callback(query, self.query_count.value) # type: ignore
self.pipeline.clear()
......@@ -167,7 +192,7 @@ class _SequencePipelineThread(_PipelineThread[DigitalSequence]):
self,
sequences: PipelineSearchTargets,
query_available: threading.Semaphore,
query_queue: typing.Deque[typing.Optional[typing.Tuple[int, DigitalSequence, _ResultBuffer]]],
query_queue: typing.Deque[typing.Optional[_Chore[DigitalSequence]]],
query_count: multiprocessing.Value, # type: ignore
kill_switch: threading.Event,
callback: typing.Optional[typing.Callable[[DigitalSequence, int], None]],
......@@ -198,7 +223,7 @@ class _MSAPipelineThread(_PipelineThread[DigitalMSA]):
self,
sequences: PipelineSearchTargets,
query_available: threading.Semaphore,
query_queue: typing.Deque[typing.Optional[typing.Tuple[int, DigitalMSA, _ResultBuffer]]],
query_queue: typing.Deque[typing.Optional[_Chore[DigitalMSA]]],
query_count: multiprocessing.Value, # type: ignore
kill_switch: threading.Event,
callback: typing.Optional[typing.Callable[[DigitalMSA, int], None]],
......@@ -253,7 +278,7 @@ class _Search(typing.Generic[_Q], abc.ABC):
def _new_thread(
self,
query_available: threading.Semaphore,
query_queue: typing.Deque[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]],
query_queue: typing.Deque[typing.Optional[_Chore[_Q]]],
query_count: "multiprocessing.Value[int]", # type: ignore
kill_switch: threading.Event,
) -> _PipelineThread[_Q]:
......@@ -273,23 +298,24 @@ class _Search(typing.Generic[_Q], abc.ABC):
# process each HMM iteratively and yield the result
# immediately so that the user can iterate over the
# TopHits one at a time
for index, query in enumerate(self.queries):
for query in self.queries:
query_count.value += 1
yield thread.process(index, query)
yield thread.process(query)
def _multi_threaded(self) -> typing.Iterator[TopHits]:
# create the semaphore which will be used to notify worker threads
# there is a new chore available
query_available = threading.Semaphore(0)
# create the queues to pass the HMM objects around, as well as atomic
# values that we use to synchronize the threads
results: typing.Deque[_ResultBuffer] = collections.deque()
# create the queues to pass the query objects around, as well as
# atomic values that we use to synchronize the threads
results: typing.Deque[_Chore[_Q]] = collections.deque()
query_queue = collections.deque() # type: ignore
query_count = multiprocessing.Value(ctypes.c_ulong)
kill_switch = threading.Event()
# the query queue is bounded so that we only feed more queries
# if the worker threads are waiting for some
query_queue = collections.deque() # type: ignore
# additional type annotations
query: typing.Optional[_Q]
index: int
# the maximum number of queries to put in the query queue, to
# avoid filling the memory: using more than one query per thread
# permits some redundancy in case a query takes very long to process.
query_bound = self.cpus * 2
# create and launch one pipeline thread per CPU
threads = []
......@@ -300,41 +326,30 @@ class _Search(typing.Generic[_Q], abc.ABC):
# catch exceptions to kill threads in the background before exiting
try:
# enumerate queries, so that we know the index of each query
# and we can yield the results in the same order
queries = enumerate(self.queries)
# initially feed one query per thread so that they can start
# working before we enter the main loop
for (index, query) in itertools.islice(queries, self.cpus):
# alternate between feeding queries to the threads and
# yielding back results, if available. the priority is
# given to filling the query queue, so that no worker
# ever idles.
for query in self.queries:
# get the next query and add it to the query queue
query_count.value += 1
query_result = _ResultBuffer()
query_queue.append((index, query, query_result))
chore = _Chore(query)
query_queue.append(chore)
query_available.release()
results.append(query_result)
# alternate between feeding queries to the threads and
# yielding back results, if available
while results:
# get the next query, or break the loop if there is no query
# left to process in the input iterator.
index, query = next(queries, (-1, None))
if query is None:
break
else:
query_count.value += 1
query_result = _ResultBuffer()
query_queue.append((index, query, query_result))
query_available.release()
results.append(query_result)
# yield the top hits for the next query, if available
if results[0].available():
yield results[0].get()
results.popleft()
results.append(chore)
# aggressively wait for the result with a very short
# timeout, and exit the loop if the queue is not full
while len(query_queue) >= query_bound:
if results[0].wait(timeout=0.01):
yield results[0].get()
results.popleft()
break
# now that we exhausted all queries, poison pill the
# threads so they stop on their own
# threads so they stop on their own gracefully
for _ in threads:
query_queue.append(None)
query_available.release()
# yield remaining results
# yield all remaining results, in order
while results:
yield results[0].get()
results.popleft()
......@@ -356,7 +371,7 @@ class _ModelSearch(typing.Generic[_M], _Search[_M]):
def _new_thread(
self,
query_available: threading.Semaphore,
query_queue: typing.Deque[typing.Optional[typing.Tuple[int, _M, _ResultBuffer]]],
query_queue: typing.Deque[typing.Optional[_Chore[_M]]],
query_count: "multiprocessing.Value[int]", # type: ignore
kill_switch: threading.Event,
) -> _ModelPipelineThread[_M]:
......@@ -392,7 +407,7 @@ class _SequenceSearch(_Search[DigitalSequence]):
def _new_thread(
self,
query_available: threading.Semaphore,
query_queue: typing.Deque[typing.Optional[typing.Tuple[int, DigitalSequence, _ResultBuffer]]],
query_queue: typing.Deque[typing.Optional[_Chore[DigitalSequence]]],
query_count: "multiprocessing.Value[int]", # type: ignore
kill_switch: threading.Event,
) -> _SequencePipelineThread:
......@@ -429,7 +444,7 @@ class _MSASearch(_Search[DigitalMSA]):
def _new_thread(
self,
query_available: threading.Semaphore,
query_queue: typing.Deque[typing.Optional[typing.Tuple[int, DigitalMSA, _ResultBuffer]]],
query_queue: typing.Deque[typing.Optional[_Chore[DigitalMSA]]],
query_count: "multiprocessing.Value[int]", # type: ignore
kill_switch: threading.Event,
) -> _MSAPipelineThread:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment