"""
Search Utilities
================
Core utility classes for similarity search: score conversion, candidate representation, and extensible filters.
Overview
--------
This module provides building blocks for the search pipeline:
* **ScoreConverter**: Handles metric-specific score transformations
* **Candidate**: Dataclass representation of a single search result (not frozen, so fields can be reassigned)
* **CandidateFilter**: Abstract base for custom filtering logic
* **Built-in Filters**: ValidGid, L2Distance, CosineSim, Length, ExactMatch, SequenceIdentity
These utilities are used internally by SearchEngine but can also be used directly for custom search pipelines.
Score Conversion
----------------
The ScoreConverter handles conversions between FAISS raw scores and user-facing outputs:
**For Cosine Similarity:**
* FAISS returns inner product scores (higher = more similar)
* ``return_similarity=True``: Output as-is [0, 1]
* ``return_similarity=False``: Convert to distance (1 - similarity)
**For L2 Distance:**
* FAISS returns squared L2 distance (lower = more similar)
* Always output as distance (no conversion)
Usage::
>>> converter = ScoreConverter(metric="cosine", return_similarity=True)
>>> output_score = converter.convert(raw_score=0.95)
>>> output_score
0.95
Candidate Representation
-------------------------
The Candidate dataclass provides a clean interface for search results:
Attributes:
score (float): Converted score/similarity
gid (int): Global sequence ID
header (str | None): Sequence header from database
length (int | None): Sequence length
stored_hash (int | None): 8-byte sequence hash for deduplication
Usage::
>>> candidate = Candidate(
... score=0.95,
... gid=12345,
... header="sp|P12345|PROT_HUMAN",
... length=234,
... stored_hash=123456789
... )
>>> candidate.as_tuple()
(0.95, 12345, 'sp|P12345|PROT_HUMAN', 234)
Custom Filters
--------------
Extend CandidateFilter to create custom filtering logic:
Example - Filter by minimum score::
class MinScoreFilter(CandidateFilter):
def __init__(self, min_score: float):
self.min_score = min_score
def apply(self, candidate: Candidate, query_seq: str = None) -> bool:
return candidate.score >= self.min_score
def get_name(self) -> str:
return "min_score"
Built-in Filters
----------------
**ValidGidFilter**
Filters out invalid GIDs (< 0). Always active in search pipeline.
Usage::
filter = ValidGidFilter()
passes = filter.apply(candidate) # False if gid < 0
**L2DistanceFilter**
Filters by minimum L2 distance (for L2 metric).
Parameters:
min_distance (float): Minimum distance threshold
Usage::
filter = L2DistanceFilter(min_distance=0.5)
passes = filter.apply(candidate) # True if distance >= 0.5
**CosineSimFilter**
Filters by maximum cosine similarity (for cosine metric).
Parameters:
max_similarity (float): Maximum similarity threshold
return_similarity (bool): Whether scores are similarities or distances
Usage::
filter = CosineSimFilter(max_similarity=0.99, return_similarity=True)
passes = filter.apply(candidate) # True if similarity <= 0.99
**LengthFilter**
Filters by sequence length range.
Parameters:
min_len (int | None): Minimum length (inclusive)
max_len (int | None): Maximum length (inclusive)
Usage::
filter = LengthFilter(min_len=50, max_len=500)
passes = filter.apply(candidate) # True if 50 <= length <= 500
**ExactMatchFilter**
Filters out exact sequence matches using hash + full comparison.
Parameters:
query_hash (int): Hash of query sequence
seq_store (SequenceStore): Database for sequence lookup
Usage::
query_hash = SequenceStore.hash8(query_seq)
filter = ExactMatchFilter(query_hash, seq_store)
passes = filter.apply(candidate, query_seq) # False if exact match
**SequenceIdentityFilter**
Filters by maximum sequence identity.
Parameters:
max_identity (float): Maximum identity threshold (0-1)
denominator (str): Identity denominator ("query", "target", "min", "max", "avg")
seq_store (SequenceStore): Database for sequence lookup
identity_func (callable): Function computing identity
Usage::
def compute_identity(seq1, seq2, denom="query"):
# Your alignment logic here
return identity_score
filter = SequenceIdentityFilter(
max_identity=0.95,
denominator="query",
seq_store=seq_store,
identity_func=compute_identity
)
passes = filter.apply(candidate, query_seq) # True if identity < 0.95
Filter Pipeline
---------------
Filters are applied sequentially in SearchEngine. First failed filter stops evaluation:
Pipeline order:
1. ValidGidFilter (always first)
2. L2DistanceFilter or CosineSimFilter (embedding-level)
3. LengthFilter (metadata-level)
4. ExactMatchFilter (sequence-level, per-query)
5. SequenceIdentityFilter (alignment-level, per-query)
This ordering minimizes expensive operations (sequence fetches, alignments).
**Optimization Tips:**
1. Use length_min/max to pre-filter via SQL index (much faster than post-filter)
2. Place cheap filters before expensive ones
3. Use hash comparison before full sequence comparison
4. Consider overfetch parameter when using aggressive filters
Integration with SearchEngine
------------------------------
SearchEngine automatically builds and applies filters based on search parameters::
results = engine.search(
queries=queries,
k=100,
nprobe=128,
# These parameters create filters internally:
length_min=50, # -> LengthFilter
length_max=500, # -> LengthFilter
max_cosine_similarity=0.99, # -> CosineSimFilter
exclude_exact_sequence=True,# -> ExactMatchFilter (per-query)
sequence_identity_max=0.95 # -> SequenceIdentityFilter (per-query)
)
See Also
--------
* :class:`SearchEngine`: Main search interface using these utilities
* :class:`SequenceStore`: Database for sequence lookups in filters
* :mod:`starling.search`: Main search module
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Tuple
from starling.search.store import SequenceStore
@dataclass
class Candidate:
    """One search hit plus the metadata carried along with it.

    All fields except ``stored_hash`` are required at construction time.
    The dataclass is not frozen, so attributes may be reassigned.
    """

    # Converted score/similarity for this hit.
    score: float
    # Global sequence ID.
    gid: int
    # Sequence header from the database, if available.
    header: Optional[str]
    # Sequence length, if available.
    length: Optional[int]
    # Sequence hash used for deduplication; optional, defaults to None.
    stored_hash: Optional[int] = None

    def as_tuple(self) -> Tuple[float, int, Optional[str], Optional[int]]:
        """Return ``(score, gid, header, length)``.

        ``stored_hash`` is deliberately left out of the tuple.
        """
        return self.score, self.gid, self.header, self.length
# ========== Filter Classes ==========
class CandidateFilter(ABC):
    """Abstract interface shared by all candidate filters.

    Concrete filters must implement both :meth:`apply` and :meth:`get_name`.
    """

    @abstractmethod
    def apply(self, candidate: Candidate, query_seq: Optional[str] = None) -> bool:
        """Decide whether ``candidate`` passes this filter.

        Args:
            candidate: The search hit under evaluation.
            query_seq: Optional query sequence for filters that need it.

        Returns:
            True when the candidate passes the filter.
        """

    @abstractmethod
    def get_name(self) -> str:
        """Return a short filter name for logging."""
class ValidGidFilter(CandidateFilter):
    """Reject candidates whose GID is negative."""

    def apply(self, candidate: Candidate, query_seq: Optional[str] = None) -> bool:
        """Return True when ``candidate.gid`` is zero or greater.

        ``query_seq`` is accepted for interface compatibility and not used.
        """
        return 0 <= candidate.gid

    def get_name(self) -> str:
        """Return the filter name used for logging."""
        return "gid"
class L2DistanceFilter(CandidateFilter):
    """Keep only candidates at or beyond a minimum L2 distance."""

    def __init__(self, min_distance: float):
        """Store the inclusive lower bound on the candidate score."""
        self.min_distance = min_distance

    def apply(self, candidate: Candidate, query_seq: Optional[str] = None) -> bool:
        """Return True when ``candidate.score`` is at least ``min_distance``.

        ``query_seq`` is accepted for interface compatibility and not used.
        """
        return self.min_distance <= candidate.score

    def get_name(self) -> str:
        """Return the filter name used for logging."""
        return "l2"
class CosineSimFilter(CandidateFilter):
    """Keep only candidates whose cosine similarity does not exceed a cap."""

    def __init__(self, max_similarity: float, return_similarity: bool):
        """Store the inclusive similarity cap and the score convention.

        Args:
            max_similarity: Largest similarity a candidate may have and pass.
            return_similarity: True when candidate scores are similarities;
                False when they are distances (``1 - similarity``).
        """
        self.max_similarity = max_similarity
        self.return_similarity = return_similarity

    def apply(self, candidate: Candidate, query_seq: Optional[str] = None) -> bool:
        """Return True when the candidate's similarity is <= ``max_similarity``.

        ``query_seq`` is accepted for interface compatibility and not used.
        """
        if self.return_similarity:
            similarity = candidate.score
        else:
            # Score is a distance here; map it back to a similarity.
            similarity = 1.0 - candidate.score
        return similarity <= self.max_similarity

    def get_name(self) -> str:
        """Return the filter name used for logging."""
        return "cosine"
class LengthFilter(CandidateFilter):
    """Keep only candidates whose length lies within an inclusive range."""

    def __init__(self, min_len: Optional[int], max_len: Optional[int]):
        """Store the bounds; ``None`` disables the corresponding bound."""
        self.min_len = min_len
        self.max_len = max_len

    def apply(self, candidate: Candidate, query_seq: Optional[str] = None) -> bool:
        """Return True when the candidate length is inside the range.

        Candidates with an unknown length (``None``) always pass.
        ``query_seq`` is accepted for interface compatibility and not used.
        """
        length = candidate.length
        if length is None:
            return True
        lower, upper = self.min_len, self.max_len
        if (lower is not None and length < lower) or (
            upper is not None and length > upper
        ):
            return False
        return True

    def get_name(self) -> str:
        """Return the filter name used for logging."""
        return "len"
class ExactMatchFilter(CandidateFilter):
    """Reject candidates whose stored sequence equals the query sequence."""

    def __init__(self, query_hash: int, seq_store: SequenceStore):
        """Store the query hash and the store used to fetch sequences."""
        self.query_hash = query_hash
        self.seq_store = seq_store

    def apply(self, candidate: Candidate, query_seq: Optional[str] = None) -> bool:
        """Return False only when the candidate is an exact match of the query.

        The hash comparison runs first so the sequence store is consulted
        only when the hashes agree. Candidates pass when they have no stored
        hash, when the hash differs, or when no ``query_seq`` is supplied.
        """
        hash_mismatch = (
            candidate.stored_hash is None
            or candidate.stored_hash != self.query_hash
        )
        if hash_mismatch or query_seq is None:
            return True
        # Hashes agree: confirm with a full sequence comparison.
        return self.seq_store.get_seq(candidate.gid) != query_seq

    def get_name(self) -> str:
        """Return the filter name used for logging."""
        return "exact"
class SequenceIdentityFilter(CandidateFilter):
    """Keep only candidates whose identity to the query is below a cap."""

    def __init__(
        self,
        max_identity: float,
        denominator: str,
        seq_store: SequenceStore,
        identity_func,
    ):
        """Store the threshold, denominator mode, sequence store and scorer.

        Args:
            max_identity: Exclusive upper bound on the computed identity.
            denominator: Passed to ``identity_func`` as the ``denom`` keyword.
            seq_store: Object whose ``get_seq(gid)`` supplies the candidate
                sequence.
            identity_func: Callable invoked as
                ``identity_func(query_seq, candidate_seq, denom=denominator)``.
        """
        self.max_identity = max_identity
        self.denominator = denominator
        self.seq_store = seq_store
        self.identity_func = identity_func

    def apply(self, candidate: Candidate, query_seq: Optional[str] = None) -> bool:
        """Return True when the computed identity is strictly below the cap.

        Candidates pass without any identity computation when ``query_seq``
        is ``None`` or when the store returns ``None`` for the candidate.
        """
        if query_seq is not None:
            target_seq = self.seq_store.get_seq(candidate.gid)
            if target_seq is not None:
                identity = self.identity_func(
                    query_seq, target_seq, denom=self.denominator
                )
                return identity < self.max_identity
        # Nothing to compare against: let the candidate through.
        return True

    def get_name(self) -> str:
        """Return the filter name used for logging."""
        return "ident"
class ScoreConverter:
    """Translate between raw scores, output scores and similarities.

    The only transformation applied is ``1.0 - value``, and only when the
    metric is ``"cosine"`` and ``return_similarity`` is false. For any other
    configuration the value is passed through unchanged.
    """

    def __init__(self, metric: str, return_similarity: bool):
        """Store the metric name and the requested output convention."""
        self.metric = metric
        self.return_similarity = return_similarity

    def _uses_cosine_distance(self) -> bool:
        """Return True when values must be complemented (cosine as distance)."""
        return self.metric == "cosine" and not self.return_similarity

    def convert(self, raw_score: float) -> float:
        """Convert a raw FAISS score to the output format.

        The result is always cast to a built-in ``float``.
        """
        if self._uses_cosine_distance():
            return float(1.0 - raw_score)
        return float(raw_score)

    def to_similarity(self, score: float) -> float:
        """Convert an output score to a similarity.

        For non-cosine metrics the score is returned unchanged.
        """
        if self._uses_cosine_distance():
            return 1.0 - score
        return score

    def to_score(self, similarity: float) -> float:
        """Convert a similarity to an output score.

        For non-cosine metrics the value is returned unchanged.
        """
        if self._uses_cosine_distance():
            return 1.0 - similarity
        return similarity