Commit f6abd16c
Authored May 18, 2022 by Martin Larralde
Parent: c3d6ce03

    Rewrite `pyhmmer.hmmer` threading code with a `Deque` to record query order

Changes: 1 file
pyhmmer/hmmer.py
@@ -28,6 +28,39 @@ _M = typing.TypeVar("_M", HMM, Profile, OptimizedProfile)
 # the sequence type for the pipeline
 _S = typing.TypeVar("_S", DigitalSequence, DigitalMSA)
 
+
+# --- Result class -----------------------------------------------------------
+
+class _ResultBuffer:
+
+    event: threading.Event
+    hits: typing.Optional[TopHits]
+    exception: typing.Optional[BaseException]
+
+    __slots__ = ("event", "hits", "exception")
+
+    def __init__(self) -> None:
+        self.event = threading.Event()
+        self.hits = None
+        self.exception = None
+
+    def available(self) -> bool:
+        return self.event.is_set()
+
+    def get(self) -> TopHits:
+        self.event.wait()
+        if self.exception is not None:
+            raise self.exception
+        return typing.cast(TopHits, self.hits)
+
+    def set(self, hits: TopHits) -> None:
+        self.hits = hits
+        self.event.set()
+
+    def fail(self, exception: BaseException) -> None:
+        self.exception = exception
+        self.event.set()
+
+
+# --- Pipeline threads -------------------------------------------------------
+
 class _PipelineThread(typing.Generic[_Q], threading.Thread):
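The `_ResultBuffer` added above is a one-shot, thread-safe slot: a worker stores either a `TopHits` result or the exception it raised and flips an `Event`, while the consumer blocks in `get()` until one of the two is available. A minimal standalone sketch of the same pattern, using a generic payload and illustrative names instead of the module's own classes:

    # sketch only: a generic result slot mirroring the set/fail/get contract above
    import threading
    import typing

    class ResultSlot:
        """One-shot container filled by a worker thread and read by a consumer."""

        def __init__(self) -> None:
            self.event = threading.Event()
            self.value: typing.Optional[int] = None
            self.exception: typing.Optional[BaseException] = None

        def set(self, value: int) -> None:
            self.value = value
            self.event.set()          # wake any consumer blocked in get()

        def fail(self, exception: BaseException) -> None:
            self.exception = exception
            self.event.set()          # also wake the consumer on failure

        def get(self) -> int:
            self.event.wait()         # block until set() or fail() was called
            if self.exception is not None:
                raise self.exception  # re-raise the worker error in the consumer thread
            return typing.cast(int, self.value)

    slot = ResultSlot()
    threading.Thread(target=lambda: slot.set(42)).start()
    print(slot.get())                 # prints 42 once the worker has finished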
@@ -37,25 +70,15 @@ class _PipelineThread(typing.Generic[_Q], threading.Thread):
         sequence (`pyhmmer.plan7.PipelineSearchTargets`): The target
             sequences to search for hits.
         query_queue (`queue.Queue`): The queue used to pass queries between
-            threads. It contains both the query and its index, so that the
-            results can be returned in the same order.
+            threads. It contains the query, its index so that the results
+            can be returned in the same order, and a `_ResultBuffer` in
+            which to store the result once the query has been processed.
         query_count (`multiprocessing.Value`): An atomic counter storing
             the total number of queries that have currently been loaded.
             Passed to the ``callback`` so that an UI can show the total
             for a progress bar.
-        hits_queue (`queue.PriorityQueue`): The queue used to pass back
-            the `TopHits` to the main thread. The results are inserted
-            using the index of the query, so that the main thread can
-            pull results in order.
         kill_switch (`threading.Event`): An event flag shared between
             all worker threads, used to notify emergency exit.
-        hits_found (`list` of `threading.Event`): A list of event flags,
-            such that ``hits_found[i]`` is set when results have been
-            obtained for the query of index ``i``. This allows the main
-            thread to keep waiting for the right `TopHits` to yield even
-            if subsequent queries have already been treated, and to make
-            sure the next result returned by ``hits_queue.get`` will also
-            be of index ``i``.
         callback (`callable`, optional): An optional callback to be called
             after each query has been processed. It should accept two
             arguments: the query object that was processed, and the total
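The docstring above describes each queue item as a three-element tuple, with `None` doubling as the shutdown signal for a worker. A small sketch of that item shape, using stand-in types rather than the real `DigitalSequence`/`TopHits` classes:

    # sketch only: the shape of a query_queue item as described above
    import queue
    import threading
    import typing

    class _Buffer:                     # placeholder for the _ResultBuffer slot
        def __init__(self) -> None:
            self.event = threading.Event()

    QueueItem = typing.Optional[typing.Tuple[int, str, _Buffer]]  # None is the poison pill

    q: "queue.Queue[QueueItem]" = queue.Queue()
    q.put((0, "first query", _Buffer()))   # a work item: (index, query, result buffer)
    q.put(None)                            # tells one worker thread to stop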
@@ -75,11 +98,9 @@ class _PipelineThread(typing.Generic[_Q], threading.Thread):
     def __init__(
         self,
         sequences: PipelineSearchTargets,
-        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q]]]",
+        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]]",
         query_count: multiprocessing.Value,  # type: ignore
-        hits_queue: "queue.PriorityQueue[typing.Tuple[int, TopHits]]",
         kill_switch: threading.Event,
-        hits_found: typing.List[threading.Event],
         callback: typing.Optional[typing.Callable[[_Q, int], None]],
         options: typing.Dict[str, typing.Any],
         pipeline_class: typing.Type[Pipeline],
@@ -89,12 +110,10 @@ class _PipelineThread(typing.Generic[_Q], threading.Thread):
         self.options = options
         self.sequences = sequences
         self.pipeline = pipeline_class(alphabet=alphabet, **options)
-        self.query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q]]]" = query_queue
+        self.query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]]" = query_queue
         self.query_count = query_count
-        self.hits_queue = hits_queue
-        self.callback: "typing.Optional[typing.Callable[[_Q, int], None]]" = callback or self._none_callback
+        self.callback: typing.Optional[typing.Callable[[_Q, int], None]] = callback or self._none_callback
         self.kill_switch = kill_switch
-        self.hits_found = hits_found
         self.error: typing.Optional[BaseException] = None
 
     def run(self) -> None:
@@ -112,28 +131,27 @@ class _PipelineThread(typing.Generic[_Q], threading.Thread):
                 self.query_queue.task_done()
                 return
             else:
-                index, query = args
+                index, query, result_buffer = args
             # process the arguments, making sure to capture any exception
             # raised while processing the query, and then mark the hits
             # as "found" using a `threading.Event` for each.
             try:
-                self.process(index, query)
+                hits = self.process(index, query)
                 self.query_queue.task_done()
+                result_buffer.set(hits)
             except BaseException as exc:
                 self.error = exc
                 self.kill()
-                return
-            finally:
-                self.hits_found[index].set()
+                result_buffer.fail(exc)
 
     def kill(self) -> None:
         self.kill_switch.set()
 
-    def process(self, index: int, query: _Q) -> None:
+    def process(self, index: int, query: _Q) -> TopHits:
         hits = self.search(query)
-        self.hits_queue.put((index, hits))
         self.callback(query, self.query_count.value)  # type: ignore
         self.pipeline.clear()
+        return hits
 
     @abc.abstractmethod
     def search(self, query: _Q) -> TopHits:
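With the change above, each worker pulls `(index, query, result_buffer)` items, computes the hits, and hands them back through the buffer (or fails it on error), with `None` acting as the poison pill. A self-contained sketch of a loop with the same structure, where `Slot` and the fake search step are stand-ins for `_ResultBuffer` and the HMMER pipeline:

    # sketch only: a worker loop shaped like _PipelineThread.run()
    import queue
    import threading
    import typing

    class Slot:
        def __init__(self) -> None:
            self.event = threading.Event()
            self.value: typing.Optional[int] = None
            self.error: typing.Optional[BaseException] = None
        def set(self, value: int) -> None:
            self.value = value
            self.event.set()
        def fail(self, error: BaseException) -> None:
            self.error = error
            self.event.set()

    def worker(work_queue: "queue.Queue", kill_switch: threading.Event) -> None:
        while not kill_switch.is_set():
            args = work_queue.get()
            if args is None:                # poison pill: stop this worker
                work_queue.task_done()
                return
            index, query, slot = args
            try:
                hits = len(query)           # stand-in for Pipeline.search_*()
                work_queue.task_done()
                slot.set(hits)              # deliver the result through the slot
            except BaseException as exc:
                kill_switch.set()           # ask the other workers to stop too
                slot.fail(exc)              # surface the error to the consumer

    kill = threading.Event()
    jobs: "queue.Queue" = queue.Queue()
    slot = Slot()
    jobs.put((0, "ACGT", slot))
    jobs.put(None)
    worker(jobs, kill)                      # runs inline for the example
    print(slot.value)                       # -> 4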
@@ -149,11 +167,9 @@ class _SequencePipelineThread(_PipelineThread[DigitalSequence]):
     def __init__(
         self,
         sequences: PipelineSearchTargets,
-        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalSequence]]]",
+        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalSequence, _ResultBuffer]]]",
         query_count: multiprocessing.Value,  # type: ignore
-        hits_queue: "queue.PriorityQueue[typing.Tuple[int, TopHits]]",
         kill_switch: threading.Event,
-        hits_found: typing.List[threading.Event],
         callback: typing.Optional[typing.Callable[[DigitalSequence, int], None]],
         options: typing.Dict[str, typing.Any],
         pipeline_class: typing.Type[Pipeline],
@@ -164,9 +180,7 @@ class _SequencePipelineThread(_PipelineThread[DigitalSequence]):
             sequences,
             query_queue,
             query_count,
-            hits_queue,
             kill_switch,
-            hits_found,
             callback,
             options,
             pipeline_class,
@@ -182,11 +196,9 @@ class _MSAPipelineThread(_PipelineThread[DigitalMSA]):
     def __init__(
         self,
         sequences: PipelineSearchTargets,
-        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalMSA]]]",
+        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalMSA, _ResultBuffer]]]",
         query_count: multiprocessing.Value,  # type: ignore
-        hits_queue: "queue.PriorityQueue[typing.Tuple[int, TopHits]]",
         kill_switch: threading.Event,
-        hits_found: typing.List[threading.Event],
         callback: typing.Optional[typing.Callable[[DigitalMSA, int], None]],
         options: typing.Dict[str, typing.Any],
         pipeline_class: typing.Type[Pipeline],
@@ -197,9 +209,7 @@ class _MSAPipelineThread(_PipelineThread[DigitalMSA]):
             sequences,
             query_queue,
             query_count,
-            hits_queue,
             kill_switch,
-            hits_found,
             callback,
             options,
             pipeline_class,
@@ -239,39 +249,33 @@ class _Search(typing.Generic[_Q], abc.ABC):
     @abc.abstractmethod
     def _new_thread(
         self,
-        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q]]]",
+        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]]",
         query_count: "multiprocessing.Value[int]",  # type: ignore
-        hits_queue: "queue.PriorityQueue[typing.Tuple[int, TopHits]]",
         kill_switch: threading.Event,
-        hits_found: typing.List[threading.Event],
     ) -> _PipelineThread[_Q]:
         return NotImplemented
 
     def _single_threaded(self) -> typing.Iterator[TopHits]:
         # create the queues to pass the HMM objects around, as well as atomic
         # values that we use to synchronize the threads
-        hits_found: typing.List[threading.Event] = []
         query_queue = queue.Queue()  # type: ignore
         query_count = multiprocessing.Value(ctypes.c_ulong)
-        hits_queue = queue.PriorityQueue()  # type: ignore
         kill_switch = threading.Event()
 
         # create the thread (to recycle code)
-        thread = self._new_thread(query_queue, query_count, hits_queue, kill_switch, hits_found)
+        thread = self._new_thread(query_queue, query_count, kill_switch)
 
         # process each HMM iteratively and yield the result
         # immediately so that the user can iterate over the
         # TopHits one at a time
         for index, query in enumerate(self.queries):
             query_count.value += 1
-            thread.process(index, query)
-            yield hits_queue.get_nowait()[1]
+            yield thread.process(index, query)
 
     def _multi_threaded(self) -> typing.Iterator[TopHits]:
         # create the queues to pass the HMM objects around, as well as atomic
         # values that we use to synchronize the threads
-        hits_found: typing.List[threading.Event] = []
-        hits_queue = queue.PriorityQueue()  # type: ignore
+        results: typing.Deque[_ResultBuffer] = collections.deque()
         query_count = multiprocessing.Value(ctypes.c_ulong)
         kill_switch = threading.Event()
 
         # the query queue is bounded so that we only feed more queries
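The `results` deque above replaces the priority queue and per-query event list: result buffers are appended in submission order, and only the leftmost (oldest) one is ever popped, so output order matches input order no matter which worker finishes first. A trivial sketch of that property, with plain strings standing in for `_ResultBuffer` objects:

    # sketch only: append in submission order, consume from the left
    import collections
    import typing

    results: typing.Deque[str] = collections.deque()
    for name in ("query-0", "query-1", "query-2"):
        results.append(name)          # same order as the queries were submitted

    while results:
        print(results[0])             # always the oldest pending query first
        results.popleft()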
@@ -284,7 +288,7 @@ class _Search(typing.Generic[_Q], abc.ABC):
         # create and launch one pipeline thread per CPU
         threads = []
         for _ in range(self.cpus):
-            thread = self._new_thread(query_queue, query_count, hits_queue, kill_switch, hits_found)
+            thread = self._new_thread(query_queue, query_count, kill_switch)
             thread.start()
             threads.append(thread)
@@ -297,12 +301,12 @@ class _Search(typing.Generic[_Q], abc.ABC):
             # working before we enter the main loop
             for (index, query) in itertools.islice(queries, self.cpus):
                 query_count.value += 1
-                hits_found.append(threading.Event())
-                query_queue.put((index, query))
+                query_result = _ResultBuffer()
+                query_queue.put((index, query, query_result))
+                results.append(query_result)
 
             # alternate between feeding queries to the threads and
             # yielding back results, if available
-            hits_yielded = 0
-            while hits_yielded < query_count.value:
+            while results:
                 # get the next query, or break the loop if there is no query
                 # left to process in the input iterator.
                 index, query = next(queries, (-1, None))
@@ -310,31 +314,21 @@ class _Search(typing.Generic[_Q], abc.ABC):
                     break
                 else:
                     query_count.value += 1
-                    hits_found.append(threading.Event())
-                    query_queue.put((index, query))
+                    query_result = _ResultBuffer()
+                    query_queue.put((index, query, query_result))
+                    results.append(query_result)
                 # yield the top hits for the next query, if available
-                if hits_found[hits_yielded].is_set():
-                    yield hits_queue.get_nowait()[1]
-                    hits_yielded += 1
+                if results[0].available():
+                    yield results[0].get()
+                    results.popleft()
             # now that we exhausted all queries, poison pill the
             # threads so they stop on their own
             for _ in threads:
                 query_queue.put(None)
             # yield remaining results
-            while hits_yielded < query_count.value:
-                hits_found[hits_yielded].wait()
-                yield hits_queue.get_nowait()[1]
-                hits_yielded += 1
-        except queue.Empty:
-            # the only way we can get queue.Empty is if a thread has set
-            # the flag for `hits_found[i]` without actually putting it in
-            # the queue: this only happens when a background thread raised
-            # an exception, so we must chain it
-            for thread in threads:
-                if thread.error is not None:
-                    raise thread.error from None
-            # if this exception is otherwise a bug, make sure to reraise it
-            raise
+            while results:
+                yield results[0].get()
+                results.popleft()
         except BaseException:
             # make sure threads are killed to avoid being stuck,
             # e.g. after a KeyboardInterrupt
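The two hunks above contain the core of the rewrite: the main thread feeds `(index, query, buffer)` items to the workers and appends each buffer to the deque, then yields results strictly in submission order by waiting on the leftmost buffer, even though workers finish out of order. A self-contained end-to-end sketch of that scheme, with stand-in names (`Slot`, `fake_search`) instead of the real pyhmmer classes:

    # sketch only: ordered fan-out/fan-in with a deque of result slots
    import collections
    import queue
    import random
    import threading
    import time
    import typing

    class Slot:
        """Minimal stand-in for _ResultBuffer."""
        def __init__(self) -> None:
            self.event = threading.Event()
            self.value: typing.Optional[str] = None
        def set(self, value: str) -> None:
            self.value = value
            self.event.set()
        def get(self) -> str:
            self.event.wait()                      # block until the worker is done
            return typing.cast(str, self.value)
        def available(self) -> bool:
            return self.event.is_set()

    def fake_search(query: str) -> str:
        time.sleep(random.random() / 10)           # workers finish out of order
        return f"hits for {query}"

    def worker(work_queue: "queue.Queue") -> None:
        while True:
            args = work_queue.get()
            if args is None:                       # poison pill
                return
            index, query, slot = args
            slot.set(fake_search(query))

    work_queue: "queue.Queue" = queue.Queue()
    results: typing.Deque[Slot] = collections.deque()
    threads = [threading.Thread(target=worker, args=(work_queue,)) for _ in range(4)]
    for thread in threads:
        thread.start()

    # feed queries: the deque records submission order
    for index, query in enumerate(f"query-{i}" for i in range(8)):
        slot = Slot()
        work_queue.put((index, query, slot))
        results.append(slot)
    for _ in threads:
        work_queue.put(None)                       # stop the workers

    # drain results strictly in submission order, whatever order they finished in
    while results:
        print(results[0].get())
        results.popleft()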
@@ -352,19 +346,15 @@ class _ModelSearch(typing.Generic[_M], _Search[_M]):
     def _new_thread(
         self,
-        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _M]]]",
+        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _M, _ResultBuffer]]]",
         query_count: "multiprocessing.Value[int]",  # type: ignore
-        hits_queue: "queue.PriorityQueue[typing.Tuple[int, TopHits]]",
         kill_switch: threading.Event,
-        hits_found: typing.List[threading.Event],
     ) -> _ModelPipelineThread[_M]:
         return _ModelPipelineThread(
             self.sequences,
             query_queue,
             query_count,
-            hits_queue,
             kill_switch,
-            hits_found,
             self.callback,
             self.options,
             self.pipeline_class,
@@ -390,19 +380,15 @@ class _SequenceSearch(_Search[DigitalSequence]):
     def _new_thread(
         self,
-        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalSequence]]]",
+        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalSequence, _ResultBuffer]]]",
         query_count: "multiprocessing.Value[int]",  # type: ignore
-        hits_queue: "queue.PriorityQueue[typing.Tuple[int, TopHits]]",
         kill_switch: threading.Event,
-        hits_found: typing.List[threading.Event],
     ) -> _SequencePipelineThread:
         return _SequencePipelineThread(
             self.sequences,
             query_queue,
             query_count,
-            hits_queue,
             kill_switch,
-            hits_found,
             self.callback,
             self.options,
             self.pipeline_class,
@@ -429,19 +415,15 @@ class _MSASearch(_Search[DigitalMSA]):
     def _new_thread(
         self,
-        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalMSA]]]",
+        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalMSA, _ResultBuffer]]]",
         query_count: "multiprocessing.Value[int]",  # type: ignore
-        hits_queue: "queue.PriorityQueue[typing.Tuple[int, TopHits]]",
         kill_switch: threading.Event,
-        hits_found: typing.List[threading.Event],
     ) -> _MSAPipelineThread:
         return _MSAPipelineThread(
             self.sequences,
             query_queue,
             query_count,
-            hits_queue,
             kill_switch,
-            hits_found,
             self.callback,
             self.options,
             self.pipeline_class,
@@ -493,7 +475,7 @@ def hmmsearch(
     """
     # count the number of CPUs to use
-    _cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count()
+    _cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1
     runner: _ModelSearch[_M] = _ModelSearch(queries, sequences, _cpus, callback, **options)  # type: ignore
     return runner.run()
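This hunk (and the two matching ones below) appends `or 1` to the CPU-count fallback: both `psutil.cpu_count(logical=False)` and `os.cpu_count()` are documented to return `None` when the count cannot be determined, so the extra term guarantees at least one worker thread. A small sketch of the same fallback chain, wrapped in a hypothetical helper for illustration:

    # sketch only: the fallback chain used above, as a standalone helper
    import os
    import psutil

    def effective_cpus(cpus: int) -> int:
        if cpus > 0:
            return cpus                                  # explicit user request wins
        # physical cores, else logical cores, else a single worker
        return psutil.cpu_count(logical=False) or os.cpu_count() or 1

    print(effective_cpus(0))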
@@ -541,7 +523,7 @@ def phmmer(
         Allow using `DigitalMSA` queries.
 
     """
-    _cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count()
+    _cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1
     _builder = Builder(Alphabet.amino()) if builder is None else builder
 
     try:
@@ -626,7 +608,7 @@ def nhmmer(
         Allow using `Profile` and `OptimizedProfile` queries.
 
     """
-    _cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count()
+    _cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1
     _builder = Builder(Alphabet.dna()) if builder is None else builder
 
     try: