
API reference

Auto-generated from docstrings. The pipeline is the primary public API; the orchestration stages and evaluation metrics are documented for programmatic use.

Pipeline

trialmatchai.pipeline

The single TrialMatchAI pipeline: an ordered registry of idempotent stages.

Every command is a slice of this one pipeline. Each stage wraps an already-idempotent orchestration function (it internally skips work that is done), so the driver only decides which stages to run from the user's selection (--only / --skip / --from / --to) and which to force (--force).

Because each stage is idempotent, running the whole pipeline from any starting state "just works": finished stages are cheap no-ops, unfinished ones run. That is the "one e2e workflow, maximally modular, never redo finished work" contract — a stage is the unit of modularity, and the e2e run is simply "run every stage".
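
The contract can be illustrated with a minimal sketch. Everything in it (`make_stage`, the `done` set, the `log` list) is hypothetical and only mirrors the idea that each stage checks its own completion state; it is not the package's implementation.

```python
from typing import Callable


def make_stage(name: str, done: set[str], log: list[str]) -> Callable[[], None]:
    """Return a toy idempotent stage: it works once, then becomes a no-op."""

    def run() -> None:
        if name in done:  # work already finished, so this is a cheap no-op
            log.append(f"skip {name}")
            return
        log.append(f"run {name}")
        done.add(name)  # record completion so later runs skip this stage

    return run


done: set[str] = {"prepare"}  # pretend an earlier run already finished "prepare"
log: list[str] = []
stages = [make_stage(n, done, log) for n in ("prepare", "index", "match")]

for stage in stages:  # first end-to-end run: only unfinished stages do work
    stage()
for stage in stages:  # second end-to-end run: every stage is a no-op
    stage()
```

After both loops, `log` records one real run of `index` and `match` and skips for everything else, which is the "never redo finished work" behaviour described above.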

STAGES module-attribute

STAGES = (
    Stage(
        "prepare",
        _run_prepare,
        "embed + entity-annotate the trial corpus",
    ),
    Stage(
        "concepts",
        _run_concepts,
        "build the entity-linking concept store",
    ),
    Stage(
        "link",
        _run_link,
        "link extracted entities to concept IDs (idempotent)",
    ),
    Stage(
        "index",
        _run_index,
        "build the LanceDB search tables",
    ),
    Stage(
        "ingest",
        _run_ingest,
        "import patient inputs into canonical profiles",
    ),
    Stage(
        "expand",
        _run_expand,
        "CoT query expansion of patient summaries",
    ),
    Stage(
        "match",
        _run_match,
        "retrieval + reranking + CoT eligibility + ranking",
    ),
    Stage(
        "eval",
        _run_eval,
        "score results against qrels (benchmark runs)",
    ),
)

StageContext dataclass

Everything the stages need, resolved once and threaded through the run.

Source code in src/trialmatchai/pipeline.py
@dataclass
class StageContext:
    """Everything the stages need, resolved once and threaded through the run."""

    config: dict[str, Any]
    trials_json_folder: Path | None = None
    processed_trials_folder: Path = Path("data/processed_trials")
    processed_criteria_folder: Path = Path("data/processed_criteria")
    inputs: list[str] = field(default_factory=list)
    input_format: str = "auto"
    with_entities: bool = True
    nct_filter: set[str] | None = None
    concepts: str | None = None  # "open" -> build the open concept store
    concept_csv: str | None = None
    synonym_csv: str | None = None
    qrels: dict | None = None  # provided by the TREC preset -> enables eval
    results_dir: Path | None = None
    force: set[str] = field(default_factory=set)

    def forced(self, name: str) -> bool:
        return name in self.force or "all" in self.force
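
As a small illustration of the `force` semantics, the sketch below copies only the `force` field and the `forced` method into a hypothetical `Ctx` class (the real `StageContext` has many more fields and requires `config`).

```python
from dataclasses import dataclass, field


@dataclass
class Ctx:
    """Hypothetical cut-down stand-in for StageContext."""

    force: set[str] = field(default_factory=set)

    def forced(self, name: str) -> bool:
        # A stage is forced if it is named explicitly or "all" was requested.
        return name in self.force or "all" in self.force


one = Ctx(force={"index"})
everything = Ctx(force={"all"})
nothing = Ctx()
```

Here `one.forced("index")` is true while `one.forced("match")` is false, `everything.forced(...)` is true for any stage name, and `nothing.forced(...)` is always false.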

Stage dataclass

Source code in src/trialmatchai/pipeline.py
@dataclass(frozen=True)
class Stage:
    name: str
    run: Callable[[StageContext], None]
    help: str

select_stages

select_stages(
    *, only=None, skip=(), from_stage=None, to_stage=None
)

Resolve the user's selection into an ordered list of stages to run.

Source code in src/trialmatchai/pipeline.py
def select_stages(
    *,
    only: Sequence[str] | None = None,
    skip: Sequence[str] = (),
    from_stage: str | None = None,
    to_stage: str | None = None,
) -> list[Stage]:
    """Resolve the user's selection into an ordered list of stages to run."""
    if only:
        _validate(only)
        chosen = set(only)
        return [s for s in STAGES if s.name in chosen]

    for endpoint in (from_stage, to_stage):
        if endpoint is not None:
            _validate([endpoint])
    _validate(skip)

    start = STAGE_NAMES.index(from_stage) if from_stage else 0
    end = STAGE_NAMES.index(to_stage) + 1 if to_stage else len(STAGES)
    if start > end - 1:
        raise ValueError(f"--from {from_stage} is after --to {to_stage}")
    skipped = set(skip)
    return [s for s in STAGES[start:end] if s.name not in skipped]
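
To make the slicing rules concrete, here is a name-only stand-in for `select_stages`. The function `select_names` and the `NAMES` tuple are hypothetical; they mirror the logic above over plain strings, with validation omitted, so the example runs without the package installed.

```python
from typing import Optional, Sequence

# Stage names in registry order, copied from STAGES above.
NAMES = ("prepare", "concepts", "link", "index", "ingest", "expand", "match", "eval")


def select_names(
    *,
    only: Optional[Sequence[str]] = None,
    skip: Sequence[str] = (),
    from_stage: Optional[str] = None,
    to_stage: Optional[str] = None,
) -> list[str]:
    """Name-only stand-in for select_stages (validation omitted)."""
    if only:
        chosen = set(only)
        # Result follows registry order, not the order of the arguments.
        return [n for n in NAMES if n in chosen]
    start = NAMES.index(from_stage) if from_stage else 0
    end = NAMES.index(to_stage) + 1 if to_stage else len(NAMES)
    if start > end - 1:
        raise ValueError(f"--from {from_stage} is after --to {to_stage}")
    skipped = set(skip)
    return [n for n in NAMES[start:end] if n not in skipped]


print(select_names(only=["match", "ingest"]))
print(select_names(from_stage="index", to_stage="match", skip=["expand"]))
```

The first call prints `['ingest', 'match']` and the second prints `['index', 'ingest', 'match']`. Note that, as in the source above, a non-empty `only` returns early, so `skip`, `from_stage`, and `to_stage` are ignored when it is given.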

run_pipeline

run_pipeline(
    ctx,
    *,
    only=None,
    skip=(),
    from_stage=None,
    to_stage=None,
)

Run the selected pipeline slice, freeing GPU models once at the end.

Source code in src/trialmatchai/pipeline.py
def run_pipeline(
    ctx: StageContext,
    *,
    only: Sequence[str] | None = None,
    skip: Sequence[str] = (),
    from_stage: str | None = None,
    to_stage: str | None = None,
) -> int:
    """Run the selected pipeline slice, freeing GPU models once at the end."""
    stages = select_stages(only=only, skip=skip, from_stage=from_stage, to_stage=to_stage)
    if not stages:
        logger.warning("No stages selected; nothing to do.")
        return 0
    logger.info("Pipeline: %s", " -> ".join(s.name for s in stages))
    try:
        for stage in stages:
            logger.info("================ stage: %s ================", stage.name)
            stage.run(ctx)
    finally:
        from trialmatchai.orchestration import free_models

        free_models()
    logger.info("Pipeline complete: %s", " -> ".join(s.name for s in stages))
    return 0
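
The `try`/`finally` shape matters when a stage raises. The sketch below uses hypothetical stand-ins (`free_models`, `failing_stage`, `run`) to show that cleanup still happens and the exception still propagates to the caller.

```python
calls: list[str] = []


def free_models() -> None:
    """Hypothetical stand-in for trialmatchai.orchestration.free_models."""
    calls.append("free_models")


def failing_stage() -> None:
    calls.append("stage")
    raise RuntimeError("stage failed")


def run(stages) -> int:
    try:
        for stage in stages:
            stage()
    finally:
        free_models()  # runs on success and on failure alike
    return 0


try:
    run([failing_stage])
except RuntimeError as exc:
    calls.append(f"raised: {exc}")
```

`calls` ends up as `["stage", "free_models", "raised: stage failed"]`: models are freed exactly once, and because the error is re-raised, the "Pipeline complete" log line and the `return 0` in `run_pipeline` are not reached on failure.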

Orchestration stages

trialmatchai.orchestration

Idempotent end-to-end orchestration for TrialMatchAI.

Chains the three pipeline stages — ingest patient inputs, build the search index, run matching — and skips work that is already done:

  • ingest: a patient is skipped if its canonical profile already exists.
  • index: the stage is skipped if the search tables already exist.
  • match: a patient is skipped if it already has a non-empty ranked_trials.json.

Both the general trialmatchai e2e command and the TREC preset are thin wrappers over these stages, so idempotency behaves identically everywhere.

ingest_inputs

ingest_inputs(
    config,
    inputs,
    *,
    input_format="auto",
    with_entities=True,
    force=False,
)

Import patient inputs (any supported format) into canonical profiles.

Skips a patient whose profile already exists unless force. Returns the number of profiles available afterwards.

Source code in src/trialmatchai/orchestration.py
def ingest_inputs(
    config: Dict[str, Any],
    inputs: Sequence[str | Path],
    *,
    input_format: str = "auto",
    with_entities: bool = True,
    force: bool = False,
) -> int:
    """Import patient inputs (any supported format) into canonical profiles.

    Skips a patient whose profile already exists unless ``force``. Returns the
    number of profiles available afterwards.
    """
    from trialmatchai.interop.exporters import profile_to_matching_summary
    from trialmatchai.interop.importers import import_patient_path

    patient_cfg = config.get("patient_inputs", {})
    profile_dir = Path(patient_cfg.get("profile_dir", "data/patients/profiles"))
    summary_dir = Path(patient_cfg.get("summary_dir", "data/patients/summaries"))
    profile_dir.mkdir(parents=True, exist_ok=True)
    summary_dir.mkdir(parents=True, exist_ok=True)

    entity_annotator = _maybe_entity_annotator(config) if with_entities else None
    strict = bool(patient_cfg.get("strict_validation", False))

    imported = 0
    for raw in inputs:
        profiles = import_patient_path(
            raw,
            input_format=input_format,
            entity_annotator=entity_annotator,
            strict=strict,
        )
        for profile in profiles:
            profile_path = profile_dir / f"{profile.patient_id}.json"
            if not force and is_valid_json_file(str(profile_path)):
                logger.info("Ingest skipped (exists): %s", profile.patient_id)
                continue
            # Summary first; the profile JSON is written last (atomically) as the
            # completion marker, so a crash between them re-imports rather than
            # leaving a profile with no summary.
            write_json_file(
                profile_to_matching_summary(profile),
                str(summary_dir / f"{profile.patient_id}.json"),
            )
            write_json_file(
                profile.model_dump(mode="json", exclude_none=True),
                str(profile_path),
            )
            imported += 1
            logger.info("Ingested patient %s", profile.patient_id)

    total = len(list(profile_dir.glob("*.json")))
    logger.info("Ingest stage: %s new, %s profiles total", imported, total)
    return total
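
The write ordering in the loop above (summary first, profile last as the completion marker) can be sketched in isolation. The helpers `write_json_atomic` and `ingest_one` are hypothetical; in particular, the temp-file-plus-rename approach is an assumption about how an atomic write might be done, and the skip check uses plain existence where the real code calls `is_valid_json_file`.

```python
import json
import os
import tempfile
from pathlib import Path


def write_json_atomic(payload: dict, path: Path) -> None:
    """Write to a temp file in the target directory, then rename over the target."""
    fd, tmp = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    with os.fdopen(fd, "w", encoding="utf-8") as handle:
        json.dump(payload, handle)
    os.replace(tmp, path)  # atomic when source and target share a filesystem


def ingest_one(patient_id: str, profile_dir: Path, summary_dir: Path) -> bool:
    """Return True if the patient was written, False if it was skipped."""
    profile_path = profile_dir / f"{patient_id}.json"
    if profile_path.exists():  # completion marker present, nothing to do
        return False
    write_json_atomic({"summary_of": patient_id}, summary_dir / f"{patient_id}.json")
    write_json_atomic({"patient_id": patient_id}, profile_path)  # marker written last
    return True


with tempfile.TemporaryDirectory() as root:
    profiles = Path(root) / "profiles"
    summaries = Path(root) / "summaries"
    profiles.mkdir()
    summaries.mkdir()
    first = ingest_one("p1", profiles, summaries)  # writes both files
    second = ingest_one("p1", profiles, summaries)  # skipped: profile exists
    # Simulate a crash after the summary was written but before the profile.
    write_json_atomic({"summary_of": "p2"}, summaries / "p2.json")
    recovered = ingest_one("p2", profiles, summaries)  # re-imported, not skipped
```

Because the skip check keys on the profile file only, a summary left behind by a crash never causes a patient to be treated as done.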

expand_queries

expand_queries(config, *, force=False)

Enrich each patient's matching summary via the CoT query expander.

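No-op unless query_expansion.enabled. Loads the model once and enriches every summary. Afterwards only the expander wrapper is released; per the comment in the source below, the underlying CoT engine stays in the vLLM engine cache so the match stage reuses it instead of loading a second copy. Idempotent: a summary already marked query_expanded is skipped unless force is set.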

Source code in src/trialmatchai/orchestration.py
def expand_queries(config: Dict[str, Any], *, force: bool = False) -> int:
    """Enrich each patient's matching summary via the CoT query expander.

    No-op unless ``query_expansion.enabled``. Loads the model once, enriches
    every summary, then releases the expander wrapper; the CoT engine stays in
    the vLLM engine cache and is reused by the match stage.
    Idempotent: a summary already marked ``query_expanded`` is skipped.
    """
    from trialmatchai.matching.query_expansion import build_query_expander, enrich_summary

    expander = build_query_expander(config)
    if expander is None:
        logger.info("Query expansion disabled; using deterministic summaries.")
        return 0

    patient_cfg = config.get("patient_inputs", {})
    profile_dir = Path(patient_cfg.get("profile_dir", "data/patients/profiles"))
    summary_dir = Path(patient_cfg.get("summary_dir", "data/patients/summaries"))
    enriched = 0
    for profile_path in sorted(profile_dir.glob("*.json")):
        pid = profile_path.stem
        summary_path = summary_dir / f"{pid}.json"
        if not summary_path.exists():
            continue
        summary = _read_json(summary_path)
        if not force and summary.get("query_expanded"):
            continue
        profile = _read_json(profile_path)
        narrative = [n.get("text", "") for n in profile.get("notes", []) if n.get("text")]
        if not narrative:
            narrative = list(summary.get("patient_narrative", []))
        expansion = expander.expand(narrative)
        merged = enrich_summary(summary, expansion)
        merged["query_expanded"] = True
        write_json_file(merged, str(summary_path))
        enriched += 1
        logger.info("Query-expanded %s", pid)

    # Release the expander wrapper; its CoT engine stays in the vLLM engine cache
    # and is reused by the match stage (shared — a single load, not two).
    del expander
    logger.info("Query-expansion stage: enriched %s summaries.", enriched)
    return enriched

build_index

build_index(
    config,
    *,
    processed_trials_folder="data/processed_trials",
    processed_criteria_folder="data/processed_criteria",
    nct_filter=None,
    force=False,
)

Build the LanceDB search tables, optionally restricted to nct_filter.

Skips when both tables already exist unless force. The backend (and thus the target db path) comes from config['search_backend'].

Source code in src/trialmatchai/orchestration.py
def build_index(
    config: Dict[str, Any],
    *,
    processed_trials_folder: str | Path = "data/processed_trials",
    processed_criteria_folder: str | Path = "data/processed_criteria",
    nct_filter: Iterable[str] | None = None,
    force: bool = False,
) -> dict[str, Any]:
    """Build the LanceDB search tables, optionally restricted to ``nct_filter``.

    Skips when both tables already exist unless ``force``. The backend (and thus
    the target db path) comes from ``config['search_backend']``.
    """
    backend = LanceDBSearchBackend.from_config(config)
    search_cfg = config.get("search_backend", {})
    trials_table = search_cfg.get("trials_table", "trials")
    criteria_table = search_cfg.get("criteria_table", "criteria")

    if (
        not force
        and backend.table_exists(trials_table)
        and backend.table_exists(criteria_table)
    ):
        logger.info("Index stage skipped: tables already present at %s", backend.db_path)
        return {"skipped": True, "db_path": str(backend.db_path)}

    nct_set = set(nct_filter) if nct_filter is not None else None
    trial_docs = list(_iter_trial_docs(Path(processed_trials_folder), nct_set))
    if not trial_docs:
        raise RuntimeError(
            f"No prepared trial documents found in {processed_trials_folder}"
            + (f" for {len(nct_set)} filtered NCT ids" if nct_set else "")
        )
    n_trials = backend.index_trials(trial_docs, recreate=True)
    logger.info("Indexed %s trial documents.", n_trials)

    criteria_docs = list(_iter_criteria_docs(Path(processed_criteria_folder), nct_set))
    if not criteria_docs:
        # Refuse to leave an inconsistent index (trials table but no criteria
        # table) where `ready_to_match` can never become true. An empty corpus is
        # a data error — the corpus is unprepared.
        raise RuntimeError(
            f"No criteria documents found in {processed_criteria_folder}"
            + (f" for the {len(nct_set)} filtered NCT ids" if nct_set else "")
            + ". The corpus appears unprepared — run `trialmatchai build` (prepare) first."
        )
    n_criteria = backend.index_criteria(criteria_docs, recreate=True)
    logger.info("Indexed %s criteria documents.", n_criteria)

    return {
        "skipped": False,
        "db_path": str(backend.db_path),
        "trials": n_trials,
        "criteria": n_criteria,
    }

run_matching

run_matching(config, *, resume=True, force=False)

Run the matching pipeline with per-patient resume.

When resuming, the expensive model stack is not even loaded if every patient is already done. The resume is additionally invalidated when the search index the matches were produced against has changed (a rebuilt corpus), so stale ranked_trials.json files are not served after a re-index.

Source code in src/trialmatchai/orchestration.py
def run_matching(
    config: Dict[str, Any],
    *,
    resume: bool = True,
    force: bool = False,
) -> int:
    """Run the matching pipeline with per-patient resume.

    When resuming, the expensive model stack is not even loaded if every patient is already
    done. The resume is additionally invalidated when the search index the matches were
    produced against has changed (a rebuilt corpus), so stale ranked_trials.json files are
    not served after a re-index.
    """
    use_resume = resume and not force
    corpus_fp = _match_corpus_fingerprint(config)
    if use_resume and _match_corpus_changed(config, corpus_fp):
        logger.info(
            "Match resume invalidated: search index changed since last match; re-matching."
        )
        use_resume = False
    if use_resume:
        pending, done = count_pending(config)
        if pending == 0:
            logger.info("Match stage skipped: all %s patient(s) already matched.", done)
            return 0
        logger.info("Match stage: %s pending, %s already done.", pending, done)
    # Imported lazily so the convert/index stages (and the CPU-only --index-only
    # path) do not pull in the heavy model stack that main.py imports.
    from trialmatchai.main import main_pipeline

    result = main_pipeline(config=config, resume=use_resume)
    _record_match_corpus(config, corpus_fp)
    return result
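
The resume decision at the top of `run_matching` reduces to a small truth table. The function `decide_resume` below is a hypothetical distillation; how `_match_corpus_changed` treats a missing record is not shown in the source, so the "no record means unchanged" rule here is an assumption.

```python
from typing import Optional


def decide_resume(
    *,
    resume: bool,
    force: bool,
    recorded_fp: Optional[str],
    current_fp: str,
) -> bool:
    """Return whether per-patient resume should be honoured for this run."""
    use_resume = resume and not force
    # Assumption: a missing record counts as "unchanged"; the real helper may differ.
    changed = recorded_fp is not None and recorded_fp != current_fp
    if use_resume and changed:
        use_resume = False  # index was rebuilt, so earlier matches are stale
    return use_resume


same_index = decide_resume(resume=True, force=False, recorded_fp="a", current_fp="a")
rebuilt_index = decide_resume(resume=True, force=False, recorded_fp="a", current_fp="b")
forced = decide_resume(resume=True, force=True, recorded_fp="a", current_fp="a")
```

Only `same_index` is true: either `force` or a changed fingerprint turns resume off, which makes every patient be matched again.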

prepare_corpus

prepare_corpus(
    config,
    *,
    trials_json_folder,
    processed_trials_folder,
    processed_criteria_folder,
    force=False,
    log_every=500,
)

Embed + annotate normalized trial JSONs into processed_*; resumable.

Streams one trial at a time (bounded memory), skips trials already prepared so an interrupted build picks up where it left off, and isolates per-trial failures so one bad document cannot abort the whole corpus.

Source code in src/trialmatchai/orchestration.py
def prepare_corpus(
    config: Dict[str, Any],
    *,
    trials_json_folder: str | Path,
    processed_trials_folder: str | Path,
    processed_criteria_folder: str | Path,
    force: bool = False,
    log_every: int = 500,
) -> dict[str, int]:
    """Embed + annotate normalized trial JSONs into processed_*; resumable.

    Streams one trial at a time (bounded memory), skips trials already prepared
    so an interrupted build picks up where it left off, and isolates per-trial
    failures so one bad document cannot abort the whole corpus.
    """
    trials_json_folder = Path(trials_json_folder)
    processed_trials_folder = Path(processed_trials_folder)
    processed_criteria_folder = Path(processed_criteria_folder)

    all_paths = sorted(trials_json_folder.glob("*.json"))
    if not all_paths:
        raise RuntimeError(f"No trial JSON files found to prepare in {trials_json_folder}")

    pending = [
        p
        for p in all_paths
        if force
        or not is_valid_json_file(str(processed_trials_folder / f"{p.stem}.json"))
    ]
    skipped = len(all_paths) - len(pending)
    logger.info(
        "Prepare: %s trials total, %s already prepared, %s to process.",
        len(all_paths),
        skipped,
        len(pending),
    )
    if not pending:
        return {"total": len(all_paths), "prepared": 0, "skipped": skipped, "failed": 0}

    from trialmatchai.entities import build_entity_annotator
    from trialmatchai.models.embedding import build_embedder
    from trialmatchai.registry.preparation import (
        prepare_criteria_documents,
        prepare_trial_document,
        write_prepared_criteria,
        write_prepared_trial,
    )

    processed_trials_folder.mkdir(parents=True, exist_ok=True)
    processed_criteria_folder.mkdir(parents=True, exist_ok=True)
    embedder = build_embedder(config)
    entity_annotator = build_entity_annotator(config, embedder=embedder)

    prepared = failed = 0
    for i, path in enumerate(pending, start=1):
        try:
            doc = _read_json(path)
            trial_row = prepare_trial_document(doc, embedder)
            criteria_rows = prepare_criteria_documents(
                doc, embedder, entity_annotator=entity_annotator
            )
            # Write criteria first; the trial JSON is written last and is the
            # per-trial completion marker the resume check keys on — so an
            # interrupted trial (criteria written, trial not) is re-processed
            # rather than wrongly skipped.
            write_prepared_criteria(criteria_rows, processed_criteria_folder)
            write_prepared_trial(trial_row, processed_trials_folder)
            prepared += 1
        except Exception:
            failed += 1
            logger.exception("Prepare failed for %s (continuing)", path.name)
        if i % log_every == 0:
            logger.info("Prepare progress: %s/%s done, %s failed.", i, len(pending), failed)

    logger.info(
        "Prepare complete: %s prepared, %s skipped, %s failed.", prepared, skipped, failed
    )
    return {"total": len(all_paths), "prepared": prepared, "skipped": skipped, "failed": failed}
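
The combination of a pending list and per-item failure isolation can be shown with an in-memory sketch. `prepare_all`, the `docs` mapping, and the made-up NCT IDs are hypothetical; a set of names stands in for the on-disk trial JSONs that act as completion markers.

```python
def prepare_all(docs: dict[str, dict], already_done: set[str]) -> dict[str, int]:
    """Process every pending doc, counting failures instead of aborting."""
    pending = [name for name in sorted(docs) if name not in already_done]
    skipped = len(docs) - len(pending)
    prepared = failed = 0
    for name in pending:
        try:
            if "nct_id" not in docs[name]:
                raise ValueError(f"{name}: missing nct_id")
            already_done.add(name)  # completion marker is set only on success
            prepared += 1
        except Exception:
            failed += 1  # isolate the failure and continue with the next doc
    return {"total": len(docs), "prepared": prepared, "skipped": skipped, "failed": failed}


docs = {"a": {"nct_id": "NCT1"}, "b": {}, "c": {"nct_id": "NCT3"}}
done = {"a"}
first = prepare_all(docs, done)
second = prepare_all(docs, done)
```

The first run reports one prepared, one skipped, and one failed document. On the second run the failed document `b` is attempted again, since it never received a completion marker, while `a` and `c` are skipped.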

build_system

build_system(
    config,
    *,
    trials_json_folder=None,
    processed_trials_folder="data/processed_trials",
    processed_criteria_folder="data/processed_criteria",
    force_prepare=False,
    force_reindex=False,
    link_concepts=False,
)

Run the setup half (prepare -> link -> index), idempotent, with a manifest.

Each stage is resumable and recorded in .trialmatchai_build.json next to the processed data, so a disrupted build can be re-run and continues from the last completed work.

Source code in src/trialmatchai/orchestration.py
def build_system(
    config: Dict[str, Any],
    *,
    trials_json_folder: str | Path | None = None,
    processed_trials_folder: str | Path = "data/processed_trials",
    processed_criteria_folder: str | Path = "data/processed_criteria",
    force_prepare: bool = False,
    force_reindex: bool = False,
    link_concepts: bool = False,
) -> dict:
    """Run the setup half (prepare -> link -> index), idempotent, with a manifest.

    Each stage is resumable and recorded in ``.trialmatchai_build.json`` next to
    the processed data, so a disrupted build can be re-run and continues from the
    last completed work.
    """
    paths = config.get("paths", {})
    trials_json_folder = Path(trials_json_folder or paths.get("trials_json_folder", "data/trials_jsons"))
    pt = Path(processed_trials_folder)
    pc = Path(processed_criteria_folder)
    manifest_path = _manifest_path(pt)
    manifest = _load_manifest(manifest_path)
    manifest["started_at"] = _now_iso()

    # Stage 1 — prepare embeddings/entities (resumable, GPU).
    have_prepared = _count_json(pt) > 0
    have_source = trials_json_folder.exists() and any(trials_json_folder.glob("*.json"))
    logger.info("=== build: prepare stage ===")
    # Skip the whole stage when its inputs (source corpus path+size+mtime), config, and code
    # version are unchanged since the last completed prepare -- no per-trial rescan/model load.
    prepare_fp = digest(
        _PREPARE_STATE_VERSION, dir_fingerprint(trials_json_folder), _prepare_signature(config)
    )
    if not force_prepare and stage_is_current(
        manifest.get("prepare"), fingerprint=prepare_fp, output_present=have_prepared
    ):
        logger.info(
            "Prepare stage skipped: source corpus + config + code unchanged (fingerprint match)."
        )
    elif have_source:
        # prepare_corpus internally skips already-prepared trials, so calling it
        # whenever source exists safely resumes without redoing finished work.
        stats = prepare_corpus(
            config,
            trials_json_folder=trials_json_folder,
            processed_trials_folder=pt,
            processed_criteria_folder=pc,
            force=force_prepare,
        )
        manifest["prepare"] = {
            **stats,
            "status": "complete",
            "fingerprint": prepare_fp,
            "output_fingerprint": dir_fingerprint(pt),
            "completed_at": _now_iso(),
        }
    elif have_prepared:
        logger.info("Prepare skipped: %s already populated (no trials_jsons source).", pt)
        manifest["prepare"] = {
            "skipped_existing": True,
            "status": "complete",
            "fingerprint": prepare_fp,
            "output_fingerprint": dir_fingerprint(pt),
            "completed_at": _now_iso(),
        }
    else:
        raise RuntimeError(
            f"Nothing to prepare: {pt} is empty and no trial JSONs at "
            f"{trials_json_folder}. Run `trialmatchai bootstrap-data` or provide "
            "normalized trial JSONs."
        )
    _save_manifest(manifest_path, manifest)

    # Stage 1b — link extracted entities to concept IDs when a concept store was
    # built this run, so the index carries concept IDs instead of leaving every
    # entity at concept_store_unavailable. Idempotent: already-linked entities are
    # skipped, so this also relinks an already-prepared NER-only corpus.
    if link_concepts:
        logger.info("=== build: link stage ===")
        # Chain off prepare's output fingerprint: when the prepared corpus, linker config, and
        # link code are unchanged since the last completed link, skip the whole stage instead
        # of re-reading every criterion file just to confirm it is already linked.
        upstream_fp = (manifest.get("prepare") or {}).get("output_fingerprint", "")
        link_fp = digest(_LINK_STATE_VERSION, upstream_fp, _linker_signature(config))
        if not force_prepare and stage_is_current(
            manifest.get("link"), fingerprint=link_fp, output_present=_count_subdirs(pc) > 0
        ):
            logger.info(
                "Link stage skipped: prepared corpus + linker config + code unchanged "
                "(fingerprint match) -- not re-reading the criteria corpus."
            )
        else:
            from trialmatchai.linking import link_corpus

            link_tally = link_corpus(
                config, processed_criteria_folder=pc, processed_trials_folder=pt
            )
            manifest["link"] = {
                **link_tally,
                "status": "complete",
                "fingerprint": link_fp,
                "completed_at": _now_iso(),
            }
        _save_manifest(manifest_path, manifest)

    # Stage 2 — build the LanceDB search index.
    logger.info("=== build: index stage ===")
    index_upstream = (manifest.get("link") or {}).get("fingerprint") or (
        manifest.get("prepare") or {}
    ).get("output_fingerprint", "")
    index_fp = digest(_INDEX_STATE_VERSION, index_upstream, _index_signature(config))
    if not force_reindex and stage_is_current(
        manifest.get("index"), fingerprint=index_fp, output_present=_index_tables_present(config)
    ):
        logger.info(
            "Index stage skipped: corpus + index config + code unchanged (fingerprint match)."
        )
        index_info = dict(manifest["index"])
    else:
        # Fingerprint changed (or forced): rebuild so the index reflects the CURRENT corpus.
        # The old table-existence-only skip would keep a stale index after the corpus changed.
        index_info = build_index(
            config, processed_trials_folder=pt, processed_criteria_folder=pc, force=True
        )
        manifest["index"] = {
            **index_info,
            "status": "complete",
            "fingerprint": index_fp,
            "completed_at": _now_iso(),
        }
    _save_manifest(manifest_path, manifest)

    state = build_state(
        config,
        processed_trials_folder=pt,
        processed_criteria_folder=pc,
    )
    manifest["state"] = state
    manifest["completed_at"] = _now_iso()
    _save_manifest(manifest_path, manifest)
    logger.info(
        "Build complete — ready_to_match=%s. Manifest: %s",
        state["ready_to_match"],
        manifest_path,
    )
    return manifest
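
The fingerprint chaining between stages can be sketched as follows. The bodies of `digest` and `stage_is_current` are not shown on this page, so the versions below are hypothetical stand-ins that only share the names and call shapes used above; the string inputs are placeholders for the real directory fingerprints and config signatures.

```python
import hashlib
from typing import Optional


def digest(*parts: str) -> str:
    """Hypothetical stand-in: hash all parts into one fingerprint."""
    h = hashlib.sha256()
    for part in parts:
        h.update(part.encode("utf-8"))
        h.update(b"\x00")  # separator so ("ab", "c") differs from ("a", "bc")
    return h.hexdigest()


def stage_is_current(
    record: Optional[dict], *, fingerprint: str, output_present: bool
) -> bool:
    """Hypothetical stand-in: complete, same fingerprint, and output still there."""
    return bool(
        record
        and record.get("status") == "complete"
        and record.get("fingerprint") == fingerprint
        and output_present
    )


manifest: dict[str, dict] = {}
prepare_out = digest("processed-corpus-v1")  # placeholder for dir_fingerprint(pt)
index_fp = digest("index-state-1", prepare_out, "index-config")
first = stage_is_current(manifest.get("index"), fingerprint=index_fp, output_present=True)
manifest["index"] = {"status": "complete", "fingerprint": index_fp}
second = stage_is_current(manifest.get("index"), fingerprint=index_fp, output_present=True)

# A changed upstream corpus changes the index fingerprint, so the stage reruns.
new_index_fp = digest("index-state-1", digest("processed-corpus-v2"), "index-config")
third = stage_is_current(manifest.get("index"), fingerprint=new_index_fp, output_present=True)
```

`first` is false (no record yet), `second` is true (nothing changed), and `third` is false because the upstream fingerprint feeds into the downstream one. That is the reason `build_system` passes `force=True` to `build_index` once the fingerprint check has already decided a rebuild is needed.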

build_state

build_state(
    config,
    *,
    processed_trials_folder="data/processed_trials",
    processed_criteria_folder="data/processed_criteria",
)

Report what the build half has produced — used by build --status.

Source code in src/trialmatchai/orchestration.py
def build_state(
    config: Dict[str, Any],
    *,
    processed_trials_folder: str | Path = "data/processed_trials",
    processed_criteria_folder: str | Path = "data/processed_criteria",
) -> dict:
    """Report what the build half has produced — used by `build --status`."""
    pt = Path(processed_trials_folder)
    pc = Path(processed_criteria_folder)
    search_cfg = config.get("search_backend", {})
    backend = LanceDBSearchBackend.from_config(config)
    trials_table = backend.table_exists(search_cfg.get("trials_table", "trials"))
    criteria_table = backend.table_exists(search_cfg.get("criteria_table", "criteria"))
    linker = config.get("concept_linker", {})
    concepts_path = Path(linker.get("db_path", "data/concepts"))
    concepts_present = concepts_path.exists() and any(concepts_path.iterdir())
    return {
        "processed_trials": {"folder": str(pt), "count": _count_json(pt)},
        "processed_criteria": {"folder": str(pc), "count": _count_subdirs(pc)},
        "index": {
            "db_path": str(backend.db_path),
            "trials_table": trials_table,
            "criteria_table": criteria_table,
        },
        "concepts": {"db_path": str(concepts_path), "present": bool(concepts_present)},
        "ready_to_match": bool(trials_table and criteria_table),
    }

Registry updater

trialmatchai.registry.updater

RegistryUpdater

Source code in src/trialmatchai/registry/updater.py
class RegistryUpdater:
    def __init__(
        self,
        *,
        client: ClinicalTrialsGovClient,
        backend: LanceDBSearchBackend,
        embedder: TextEmbeddingBackend,
        entity_annotator: EntityAnnotationBackend | None = None,
    ) -> None:
        self.client = client
        self.backend = backend
        self.embedder = embedder
        self.entity_annotator = entity_annotator

    def run(self, config: RegistryUpdateConfig) -> RegistryUpdateReport:
        report = RegistryUpdateReport(dry_run=config.dry_run)
        manifest = RegistryManifest(config.manifest_path)
        latest = manifest.load_latest()
        seen: set[str] = set()

        for keyword in config.keywords:
            remaining = _remaining(config.max_studies, report.fetched)
            if remaining == 0:
                break
            logger.info("Fetching registry studies for keyword: %s", keyword)
            try:
                for study in self.client.iter_studies(
                    keyword=keyword,
                    statuses=config.statuses,
                    since=config.since,
                    max_studies=remaining,
                ):
                    if (
                        config.max_studies is not None
                        and report.fetched >= config.max_studies
                    ):
                        break
                    self._process_study(
                        study,
                        config=config,
                        manifest=manifest,
                        latest=latest,
                        seen=seen,
                        report=report,
                    )
            except Exception as exc:
                logger.exception("Registry source fetch failed for keyword: %s", keyword)
                report.failed += 1
                report.failures.append(
                    RegistryStudyFailure(
                        nct_id=None,
                        error=f"{keyword}: {exc}",
                    )
                )

        if not config.dry_run:
            self.write_run_report(config.reports_dir, report)
        return report

    def _process_study(
        self,
        study: dict[str, Any],
        *,
        config: RegistryUpdateConfig,
        manifest: RegistryManifest,
        latest: dict[str, ManifestRecord],
        seen: set[str],
        report: RegistryUpdateReport,
    ) -> None:
        nct_id: str | None = None
        try:
            normalized = normalize_study(study)
            nct_id = str(normalized["nct_id"])
            if nct_id in seen:
                report.duplicate += 1
                return
            seen.add(nct_id)
            report.fetched += 1

            digest = source_hash(study)
            previous = latest.get(nct_id)
            if previous and previous.source_hash == digest:
                report.unchanged += 1
                return

            is_new = previous is None
            if is_new:
                report.new += 1
            else:
                report.changed += 1

            if config.dry_run:
                return

            self._write_source_and_normalized(study, normalized, config=config)
            if config.reindex_all_changed:
                prepared_trial = prepare_trial_document(normalized, self.embedder)
                prepared_criteria = prepare_criteria_documents(
                    normalized,
                    self.embedder,
                    entity_annotator=self.entity_annotator,
                )
                report.indexed += self.backend.upsert_trials([prepared_trial])
                report.criteria_indexed += self.backend.replace_criteria_for_trials(
                    [nct_id],
                    prepared_criteria,
                )

            record = ManifestRecord(
                nct_id=nct_id,
                source_url=str(normalized.get("source_url", "")),
                source_hash=digest,
                fetched_at=utc_now_iso(),
                last_update_posted=normalized.get("last_update_posted"),
                processing_status="indexed" if config.reindex_all_changed else "fetched",
            )
            manifest.append(record)
            latest[nct_id] = record
        except Exception as exc:
            logger.exception("Registry update failed for study %s", nct_id or "<unknown>")
            report.failed += 1
            report.failures.append(
                RegistryStudyFailure(nct_id=nct_id, error=str(exc))
            )
            if not config.dry_run and nct_id:
                manifest.append(
                    ManifestRecord(
                        nct_id=nct_id,
                        source_url=f"https://clinicaltrials.gov/study/{nct_id}",
                        source_hash=source_hash(study),
                        fetched_at=utc_now_iso(),
                        last_update_posted=None,
                        processing_status="failed",
                        error_summary=str(exc),
                    )
                )

    def _write_source_and_normalized(
        self,
        study: dict[str, Any],
        normalized: dict[str, Any],
        *,
        config: RegistryUpdateConfig,
    ) -> None:
        nct_id = str(normalized["nct_id"])
        _write_json(config.raw_dir / f"{nct_id}.json", study)
        _write_json(config.normalized_trials_dir / f"{nct_id}.json", normalized)

    @staticmethod
    def write_run_report(
        reports_dir: str | Path,
        report: RegistryUpdateReport,
    ) -> Path:
        reports_path = Path(reports_dir)
        reports_path.mkdir(parents=True, exist_ok=True)
        path = reports_path / f"registry-update-{utc_now_iso().replace(':', '')}.json"
        _write_json(path, report.to_dict())
        return path

RegistryUpdateConfig dataclass

Source code in src/trialmatchai/registry/updater.py
@dataclass(frozen=True)
class RegistryUpdateConfig:
    raw_dir: Path
    normalized_trials_dir: Path
    manifest_path: Path
    reports_dir: Path
    keywords: tuple[str, ...] = DEFAULT_REGISTRY_KEYWORDS
    statuses: tuple[str, ...] = DEFAULT_REGISTRY_STATUSES
    since: date | None = None
    max_studies: int | None = None
    dry_run: bool = False
    reindex_all_changed: bool = True
    failure_threshold: float = 0.25

RegistryUpdateReport dataclass

Source code in src/trialmatchai/registry/updater.py
@dataclass
class RegistryUpdateReport:
    fetched: int = 0
    new: int = 0
    changed: int = 0
    unchanged: int = 0
    failed: int = 0
    duplicate: int = 0
    indexed: int = 0
    criteria_indexed: int = 0
    dry_run: bool = False
    failures: list[RegistryStudyFailure] = field(default_factory=list)

    @property
    def failure_rate(self) -> float:
        denominator = max(1, self.fetched)
        return self.failed / denominator

    def to_dict(self) -> dict[str, Any]:
        data = asdict(self)
        data["failure_rate"] = self.failure_rate
        return data
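
As the listings above read, `failure_rate` divides by `max(1, fetched)`, and a study that raises inside `normalize_study` is counted in `failed` before `fetched` is incremented. The rate can therefore exceed 1.0 when early failures dominate. A minimal sketch, using a hypothetical `ReportCounts` stand-in that keeps only the two counters involved:

```python
from dataclasses import dataclass


@dataclass
class ReportCounts:
    """Hypothetical stand-in: only the counters that failure_rate reads."""

    fetched: int = 0
    failed: int = 0

    @property
    def failure_rate(self) -> float:
        # Same guard as the listing: never divide by zero.
        denominator = max(1, self.fetched)
        return self.failed / denominator


print(ReportCounts().failure_rate)                     # 0.0
print(ReportCounts(fetched=8, failed=2).failure_rate)  # 0.25
print(ReportCounts(fetched=0, failed=3).failure_rate)  # 3.0
```

How `failure_threshold` from `RegistryUpdateConfig` is compared against this value is not shown in this section, so the sketch does not assume it.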

Evaluation metrics

trialmatchai.trec.metrics

Ranking-quality metrics for TREC evaluation.

These complement recall@k (the retrieval-side metric in qrels). nDCG here is:

  • tie-aware (McSherry & Najork, 2008): trials sharing the same ranking score form a tie group, and each member is given the AVERAGE positional discount over the ranks the group spans (truncated at k). The result is the EXPECTED nDCG over all random orderings of the tied trials, so it is invariant to arbitrary tie-breaking — it rewards only genuinely ordering a more-relevant trial above a less-relevant one.
  • condensed: computed over the labeled-and-retrieved trials only, with the IDCG normalized to that same set. It measures the quality of the final ranking of the trials the model actually evaluated, decoupled from recall.

Gain is linear (gain = relevance grade), matching trec_eval's default and the legacy evaluation.
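
The tie-aware averaging described above can be sketched as follows. This is an illustration under stated assumptions, not the library's `tie_aware_dcg_at_k`: the name `tie_aware_dcg_sketch` is hypothetical, the discount is assumed to be the usual `1 / log2(rank + 1)`, the input is re-sorted by score, and a missing score is treated as 0.0.

```python
import math
from itertools import groupby
from typing import Mapping, Sequence


def tie_aware_dcg_sketch(
    ordered_ids: Sequence[str],
    score_of: Mapping[str, float],
    gain_of: Mapping[str, float],
    k: int,
) -> float:
    """Expected DCG@k when trials with equal scores are ordered at random."""
    # Sort by score (descending) so that tied trials are adjacent.
    ranked = sorted(ordered_ids, key=lambda d: score_of.get(d, 0.0), reverse=True)
    dcg = 0.0
    start = 0  # number of trials ranked before the current tie group
    for _, group in groupby(ranked, key=lambda d: score_of.get(d, 0.0)):
        members = list(group)
        ranks = range(start + 1, start + len(members) + 1)  # 1-based ranks
        # Ranks beyond k contribute a zero discount; average over the whole group.
        avg_discount = sum(1.0 / math.log2(r + 1) for r in ranks if r <= k) / len(members)
        dcg += avg_discount * sum(gain_of.get(d, 0.0) for d in members)
        start += len(members)
    return dcg


scores = {"a": 1.0, "b": 1.0}  # tied
gains = {"a": 2.0, "b": 0.0}
print(tie_aware_dcg_sketch(["a", "b"], scores, gains, k=1))  # 1.0
print(tie_aware_dcg_sketch(["b", "a"], scores, gains, k=1))  # 1.0
```

Both calls give the same value because the two tied trials share the average of the rank-1 discount (1.0) and the truncated rank-2 discount (0.0), which is the expectation over the two possible orderings.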

ndcg_at_k

ndcg_at_k(ordered_ids, score_of, gain_of, k)

Tie-aware nDCG@k. ordered_ids should be the condensed (labeled) list.

Source code in src/trialmatchai/trec/metrics.py
def ndcg_at_k(
    ordered_ids: Sequence[str],
    score_of: Mapping[str, float],
    gain_of: Mapping[str, float],
    k: int,
) -> float:
    """Tie-aware nDCG@k. ``ordered_ids`` should be the condensed (labeled) list."""
    if k <= 0 or not ordered_ids:
        return 0.0
    idcg = idcg_at_k([gain_of.get(d, 0.0) for d in ordered_ids], k)
    if idcg <= 0:
        return 0.0
    return tie_aware_dcg_at_k(ordered_ids, score_of, gain_of, k) / idcg

condensed_ndcg

condensed_ndcg(ranked_ids, score_of, grade_of, cutoffs)

Tie-aware nDCG@k for each cutoff, condensed to labeled-and-retrieved trials.

ranked_ids is the final ranking order; grade_of is the qrels grade for judged trials. Only trials present in grade_of are kept (condensed).

Source code in src/trialmatchai/trec/metrics.py
def condensed_ndcg(
    ranked_ids: Sequence[str],
    score_of: Mapping[str, float],
    grade_of: Mapping[str, int],
    cutoffs: Sequence[int],
) -> Dict[int, float]:
    """Tie-aware nDCG@k for each cutoff, condensed to labeled-and-retrieved trials.

    ``ranked_ids`` is the final ranking order; ``grade_of`` is the qrels grade for
    judged trials. Only trials present in ``grade_of`` are kept (condensed).
    """
    condensed = [nid for nid in ranked_ids if nid in grade_of]
    return {k: ndcg_at_k(condensed, score_of, grade_of, k) for k in cutoffs}
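
A small illustration of the condensing step alone, using made-up placeholder trial ids. Unjudged trials are dropped before any cutoff is applied, so a judged trial can move up in rank:

```python
ranked_ids = ["NCT-1", "NCT-2", "NCT-3", "NCT-4"]  # final ranking, best first
grade_of = {"NCT-2": 2, "NCT-4": 0}                # only judged trials have a grade

# Same filter as in condensed_ndcg: keep judged trials, preserve ranking order.
condensed = [nid for nid in ranked_ids if nid in grade_of]

print(condensed)                       # ['NCT-2', 'NCT-4']
print(ranked_ids.index("NCT-2") + 1)   # 2  (rank in the full list)
print(condensed.index("NCT-2") + 1)    # 1  (rank after condensing)
```

This is why the metric is decoupled from recall: an unjudged trial ranked first neither helps nor hurts the score.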

precision_at_k

precision_at_k(ordered_ids, relevant, k)

Standard binary P@k over the final ranked list (hard cutoff k).

Source code in src/trialmatchai/trec/metrics.py
def precision_at_k(ordered_ids: Sequence[str], relevant: Set[str], k: int) -> float:
    """Standard binary P@k over the final ranked list (hard cutoff k)."""
    if k <= 0:
        return 0.0
    topk = ordered_ids[:k]
    if not topk:
        return 0.0
    return sum(1 for d in topk if d in relevant) / float(k)
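
A short usage example of the hard cutoff, with the function body copied from the listing above so the snippet runs on its own (the trial ids are placeholders). Because the denominator is always `k`, a ranked list shorter than `k` is penalized for the missing positions:

```python
from typing import Sequence, Set


def precision_at_k(ordered_ids: Sequence[str], relevant: Set[str], k: int) -> float:
    if k <= 0:
        return 0.0
    topk = ordered_ids[:k]
    if not topk:
        return 0.0
    # Divide by k, not len(topk): missing positions count as misses.
    return sum(1 for d in topk if d in relevant) / float(k)


ranked = ["t1", "t2", "t3"]
relevant = {"t1", "t3"}
print(precision_at_k(ranked, relevant, 2))   # 0.5
print(precision_at_k(ranked, relevant, 10))  # 0.2
```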

trialmatchai.trec.qrels

Official TREC relevance judgments (qrels): download, parse, corpus, metrics.

The per-track NCT corpus pool is derived directly from the qrels (the set of judged trials) — replacing the previously-checked-in Unique_NCT_IDs lists. Evaluation computes recall@k of the retrieval against the same qrels.

TREC Clinical Trials relevance grades: 0 = not relevant, 1 = excluded (the trial matches the condition but the patient is excluded), 2 = eligible. By default a trial counts as relevant at grade >= 1 (matching the legacy recall evaluation); pass threshold=2 to score eligible-only.
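
The effect of the threshold can be sketched as below. `relevant_by_threshold` is a hypothetical stand-in: the module's own `relevant_ncts` is not listed in this section and may differ, for example in how it treats queries with no relevant trials.

```python
from typing import Dict, Set


def relevant_by_threshold(
    qrels: Dict[str, Dict[str, int]], threshold: int = 1
) -> Dict[str, Set[str]]:
    """Per query, the trials whose grade is at least `threshold`."""
    return {
        query_id: {nct for nct, grade in judgments.items() if grade >= threshold}
        for query_id, judgments in qrels.items()
    }


# Placeholder ids; grades follow the 0 / 1 / 2 scheme described above.
qrels = {"q1": {"NCT-A": 0, "NCT-B": 1, "NCT-C": 2}}
print(sorted(relevant_by_threshold(qrels)["q1"]))               # ['NCT-B', 'NCT-C']
print(sorted(relevant_by_threshold(qrels, threshold=2)["q1"]))  # ['NCT-C']
```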

parse_qrels

parse_qrels(path, id_prefix)

Parse a TREC qrels file into {query_id: {nct_id: relevance}}.

Lines are <topic> <iteration> <nct_id> <relevance> (whitespace separated). The query id is f"{id_prefix}{topic}" to match the imported topic ids and the per-patient results folders.

Source code in src/trialmatchai/trec/qrels.py
def parse_qrels(path: Path, id_prefix: str) -> dict[str, dict[str, int]]:
    """Parse a TREC qrels file into {query_id: {nct_id: relevance}}.

    Lines are ``<topic> <iteration> <nct_id> <relevance>`` (whitespace
    separated). The query id is ``f"{id_prefix}{topic}"`` to match the imported
    topic ids and the per-patient results folders.
    """
    qrels: dict[str, dict[str, int]] = {}
    for raw in Path(path).read_text(encoding="utf-8", errors="ignore").splitlines():
        parts = raw.split()
        if len(parts) < 4:
            continue
        topic, _iteration, nct_id, relevance = parts[0], parts[1], parts[2], parts[3]
        try:
            rel = int(relevance)
        except ValueError:
            continue
        query_id = f"{id_prefix}{topic.strip()}"
        qrels.setdefault(query_id, {})[nct_id.strip()] = rel
    if not qrels:
        raise ValueError(f"No judgments parsed from qrels file {path}")
    return qrels
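
To make the line format concrete, here is a sketch that applies the same parsing rules to an in-memory string instead of a file. `parse_qrels_text`, the `demo-` prefix, and the sample judgments are illustrative assumptions, not part of the package.

```python
def parse_qrels_text(text: str, id_prefix: str) -> dict[str, dict[str, int]]:
    """Same rules as parse_qrels, but reading from a string."""
    qrels: dict[str, dict[str, int]] = {}
    for raw in text.splitlines():
        parts = raw.split()
        if len(parts) < 4:
            continue  # too few fields: skipped silently
        topic, _iteration, nct_id, relevance = parts[:4]
        try:
            rel = int(relevance)
        except ValueError:
            continue  # non-integer grade: skipped silently
        qrels.setdefault(f"{id_prefix}{topic}", {})[nct_id] = rel
    return qrels


SAMPLE = """\
1 0 NCT00000001 2
1 0 NCT00000002 0
2 0 NCT00000001 1
broken line
3 0 NCT00000003 n/a
"""

parsed = parse_qrels_text(SAMPLE, "demo-")
print(parsed)
# {'demo-1': {'NCT00000001': 2, 'NCT00000002': 0}, 'demo-2': {'NCT00000001': 1}}

# Union of judged trials across queries, as corpus_ncts computes it.
pool = set().union(*parsed.values())
print(sorted(pool))  # ['NCT00000001', 'NCT00000002']
```

Topic 3 never appears in the result because its only line has an unparseable grade.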

corpus_ncts

corpus_ncts(qrels)

The judged-trial pool across all queries (used to restrict the index).

Source code in src/trialmatchai/trec/qrels.py
def corpus_ncts(qrels: dict[str, dict[str, int]]) -> set[str]:
    """The judged-trial pool across all queries (used to restrict the index)."""
    pool: set[str] = set()
    for judgments in qrels.values():
        pool.update(judgments)
    return pool

evaluate

evaluate(
    qrels,
    results_dir,
    *,
    cutoffs=DEFAULT_CUTOFFS,
    threshold=1,
)

Per-query and mean metrics over the patients in results_dir.

Two complementary families:
  • recall@k — retrieval quality (first-level candidate list).
  • tie-aware nDCG@{5,10,20} + P@10 — ranking quality of the final ranked_trials.json, condensed to labeled-and-retrieved trials. nDCG is order-invariant on ties (McSherry-Najork); P@10 is reported for both "relevant" (grade>=1) and "eligible" (grade==2).

Source code in src/trialmatchai/trec/qrels.py
def evaluate(
    qrels: dict[str, dict[str, int]],
    results_dir: Path,
    *,
    cutoffs: tuple[int, ...] = DEFAULT_CUTOFFS,
    threshold: int = 1,
) -> dict:
    """Per-query and mean metrics over the patients in ``results_dir``.

    Two complementary families:
      * recall@k       — retrieval quality (first-level candidate list).
      * tie-aware nDCG@{5,10,20} + P@10 — ranking quality of the final
        ranked_trials.json, condensed to labeled-and-retrieved trials. nDCG is
        order-invariant on ties (McSherry-Najork); P@10 is reported for both
        "relevant" (grade>=1) and "eligible" (grade==2).
    """
    results_dir = Path(results_dir)
    relevant = relevant_ncts(qrels, threshold=threshold)
    eligible = relevant_ncts(qrels, threshold=2)
    per_query: dict[str, dict] = {}

    rec_sums = {f"recall@{k}": 0.0 for k in cutoffs}
    rec_counts = {f"recall@{k}": 0 for k in cutoffs}
    rank_sums = {f"ndcg@{k}": 0.0 for k in NDCG_CUTOFFS}
    rank_sums[f"P@{P_CUTOFF}(rel>=1)"] = 0.0
    rank_sums[f"P@{P_CUTOFF}(eligible)"] = 0.0
    rank_counts = {key: 0 for key in rank_sums}

    for query_id, judgments in qrels.items():
        patient_dir = results_dir / query_id
        rel_set = relevant.get(query_id, set())
        if not patient_dir.is_dir() or not rel_set:
            continue
        retrieved = _retrieved_for_patient(patient_dir)
        ranked, score_of = _ranked_with_scores(patient_dir)
        row = {
            "num_relevant": len(rel_set),
            "num_retrieved": len(retrieved),
            "num_ranked": len(ranked),
        }
        for k in cutoffs:
            r = recall_at_k(retrieved, rel_set, k)
            row[f"recall@{k}"] = r
            if r is not None:
                rec_sums[f"recall@{k}"] += r
                rec_counts[f"recall@{k}"] += 1

        if ranked:
            ndcg = condensed_ndcg(ranked, score_of, judgments, NDCG_CUTOFFS)
            for k in NDCG_CUTOFFS:
                row[f"ndcg@{k}"] = ndcg[k]
                rank_sums[f"ndcg@{k}"] += ndcg[k]
                rank_counts[f"ndcg@{k}"] += 1
            p_rel = precision_at_k(ranked, rel_set, P_CUTOFF)
            p_elig = precision_at_k(ranked, eligible.get(query_id, set()), P_CUTOFF)
            row[f"P@{P_CUTOFF}(rel>=1)"] = p_rel
            row[f"P@{P_CUTOFF}(eligible)"] = p_elig
            rank_sums[f"P@{P_CUTOFF}(rel>=1)"] += p_rel
            rank_sums[f"P@{P_CUTOFF}(eligible)"] += p_elig
            rank_counts[f"P@{P_CUTOFF}(rel>=1)"] += 1
            rank_counts[f"P@{P_CUTOFF}(eligible)"] += 1
        per_query[query_id] = row

    mean = {**_mean(rec_sums, rec_counts), **_mean(rank_sums, rank_counts)}
    return {
        "recall_relevance_threshold": threshold,
        "num_queries_scored": len(per_query),
        "num_queries_ranked": rank_counts[f"ndcg@{NDCG_CUTOFFS[0]}"],
        "mean": mean,
        "per_query": per_query,
    }
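
In the listing above, the two metric families keep separate sums and counts, so their means can be taken over different numbers of queries: recall over queries where `recall_at_k` returned a value, ranking metrics only over queries that have a ranked list. A minimal sketch of that averaging, where `mean_sketch` is a hypothetical stand-in for the unlisted `_mean` helper (returning `None` for an empty count is an assumption):

```python
from typing import Dict, Optional


def mean_sketch(sums: Dict[str, float], counts: Dict[str, int]) -> Dict[str, Optional[float]]:
    """Per-key mean; None when no query contributed to that key."""
    return {key: (sums[key] / counts[key] if counts[key] else None) for key in sums}


# Three queries scored for recall, but only two of them produced a ranked list.
rec_sums = {"recall@10": 0.5 + 1.0 + 0.0}
rec_counts = {"recall@10": 3}
rank_sums = {"ndcg@10": 1.0 + 0.5}
rank_counts = {"ndcg@10": 2}

mean = {**mean_sketch(rec_sums, rec_counts), **mean_sketch(rank_sums, rank_counts)}
print(mean)  # {'recall@10': 0.5, 'ndcg@10': 0.75}
```

This is also why the returned dict reports `num_queries_scored` and `num_queries_ranked` separately.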

HTML report

trialmatchai.interop.exporters.html_report

Self-contained HTML results report for a matched patient.

Joins ranked_trials.json + the per-trial CoT eligibility evaluations + trial metadata + the patient matching summary into one offline report.html (no server, no build step). All dynamic content is embedded as a JSON island and rendered client-side via safe DOM APIs, so there is no server-side HTML templating to escape and no templating dependency.

build_report_model

build_report_model(
    *,
    patient_summary,
    ranked,
    eligibility_by_id,
    meta_by_id,
    cot_by_id=None,
    generated_at,
    run_info=None,
)

Pure join of a patient's result artifacts into a render-ready model.

No I/O — the caller supplies already-loaded data, so this is unit-testable. Trials keep ranked_trials.json order; rank is the 1-based position.

Source code in src/trialmatchai/interop/exporters/html_report.py
def build_report_model(
    *,
    patient_summary: Mapping[str, Any],
    ranked: Any,
    eligibility_by_id: Mapping[str, Any],
    meta_by_id: Mapping[str, Any],
    cot_by_id: Mapping[str, Any] | None = None,
    generated_at: str,
    run_info: Mapping[str, Any] | None = None,
) -> dict:
    """Pure join of a patient's result artifacts into a render-ready model.

    No I/O — the caller supplies already-loaded data, so this is unit-testable.
    Trials keep ``ranked_trials.json`` order; ``rank`` is the 1-based position.
    """
    trials: list[dict] = []
    for rank, rec in enumerate(_ranked_records(ranked), start=1):
        tid = str(rec.get("TrialID", "")).strip()
        if not tid:
            continue
        elig = eligibility_by_id.get(tid) or {}
        # an unparseable model response left an error sentinel, not an evaluation
        reasoning_ok = bool(elig) and elig.get("error") != _ERROR_SENTINEL
        meta = _trial_meta(meta_by_id.get(tid))
        trials.append(
            {
                "rank": rank,
                "trial_id": tid,
                "score": rec.get("Score"),
                "reranker_score": rec.get("RerankerScore"),
                "first_level_score": rec.get("FirstLevelScore"),
                "meta": meta,
                "metadata_available": bool(meta),
                "final_decision": elig.get("Final Decision") if reasoning_ok else None,
                "inclusion": _criteria(elig.get("Inclusion_Criteria_Evaluation")) if reasoning_ok else [],
                "exclusion": _criteria(elig.get("Exclusion_Criteria_Evaluation")) if reasoning_ok else [],
                "reasoning_available": reasoning_ok,
                "cot": (cot_by_id or {}).get(tid),
            }
        )
    return {
        "patient": {
            "id": patient_summary.get("patient_id"),
            "age": patient_summary.get("age"),
            "sex": patient_summary.get("gender"),
            "main_conditions": list(patient_summary.get("main_conditions", [])),
            "other_conditions": list(patient_summary.get("other_conditions", [])),
            "narrative": _narrative(patient_summary.get("patient_narrative")),
        },
        "trials": trials,
        "generated_at": generated_at,
        "run": dict(run_info or {}),
    }

profile_to_model

profile_to_model(
    patient_dir,
    *,
    summary_dir=None,
    trial_meta_folders=None,
    generated_at=None,
    run_info=None,
)

Read a patient's result dir into a render-ready model (no HTML).

patient_dir is <output_dir>/<patient_id>/. Metadata folders are tried in order (processed_trials, then trials_jsons); ids with no metadata (e.g. Dutch NL… registry trials) degrade to id + score + verdict only.

Source code in src/trialmatchai/interop/exporters/html_report.py
def profile_to_model(
    patient_dir: str | Path,
    *,
    summary_dir: str | Path | None = None,
    trial_meta_folders: list[str | Path] | None = None,
    generated_at: str | None = None,
    run_info: Mapping[str, Any] | None = None,
) -> dict:
    """Read a patient's result dir into a render-ready model (no HTML).

    ``patient_dir`` is ``<output_dir>/<patient_id>/``. Metadata folders are tried
    in order (processed_trials, then trials_jsons); ids with no metadata (e.g.
    Dutch ``NL…`` registry trials) degrade to id + score + verdict only.
    """
    patient_dir = Path(patient_dir)
    ranked = _read_json(patient_dir / "ranked_trials.json") or {}
    records = _ranked_records(ranked)
    folders = [Path(f) for f in (trial_meta_folders or ["data/processed_trials", "data/trials_jsons"])]

    eligibility_by_id: dict[str, Any] = {}
    meta_by_id: dict[str, Any] = {}
    cot_by_id: dict[str, Any] = {}
    for rec in records:
        tid = str(rec.get("TrialID", "")).strip()
        if not tid:
            continue
        elig = _read_json(patient_dir / f"{tid}.json")
        if elig is not None:
            eligibility_by_id[tid] = elig
        if tid not in meta_by_id:
            meta = _load_meta(tid, folders)
            if meta:
                meta_by_id[tid] = meta
        cot = _read_cot(patient_dir / f"{tid}.txt")
        if cot:
            cot_by_id[tid] = cot

    patient_id = patient_dir.name
    summary = _read_json(Path(summary_dir) / f"{patient_id}.json") if summary_dir else None
    summary = summary or _read_json(patient_dir / "keywords.json")
    if not isinstance(summary, dict):
        summary = {}
    summary.setdefault("patient_id", patient_id)  # never let the id be null (drives routing)

    return build_report_model(
        patient_summary=summary,
        ranked=ranked,
        eligibility_by_id=eligibility_by_id,
        meta_by_id=meta_by_id,
        cot_by_id=cot_by_id,
        generated_at=generated_at or datetime.now().strftime("%Y-%m-%d %H:%M"),
        run_info=run_info,
    )

profile_to_html_report

profile_to_html_report(
    patient_dir,
    *,
    summary_dir=None,
    trial_meta_folders=None,
    generated_at=None,
    run_info=None,
)

Read a patient's result dir and return a self-contained single-patient report.

Source code in src/trialmatchai/interop/exporters/html_report.py
def profile_to_html_report(
    patient_dir: str | Path,
    *,
    summary_dir: str | Path | None = None,
    trial_meta_folders: list[str | Path] | None = None,
    generated_at: str | None = None,
    run_info: Mapping[str, Any] | None = None,
) -> str:
    """Read a patient's result dir and return a self-contained single-patient report."""
    model = profile_to_model(
        patient_dir,
        summary_dir=summary_dir,
        trial_meta_folders=trial_meta_folders,
        generated_at=generated_at,
        run_info=run_info,
    )
    return render_html_report(model)

render_unified_html

render_unified_html(patient_models, generated_at)

One self-contained report over many patients: a front page listing every patient that drills into the per-patient view client-side.

Source code in src/trialmatchai/interop/exporters/html_report.py
def render_unified_html(patient_models: Sequence[Mapping[str, Any]], generated_at: str) -> str:
    """One self-contained report over many patients: a front page listing every
    patient that drills into the per-patient view client-side."""
    return render_html_report({"patients": list(patient_models), "generated_at": generated_at})

render_html_report

render_html_report(model)

Embed the model as a tag-safe JSON island in the static template.

Accepts a single-patient model ({"patient", "trials", ...}) or a unified one ({"patients": [...]}). A single model is wrapped so the template always reads DATA.patients and a one-patient report skips the front page.

Source code in src/trialmatchai/interop/exporters/html_report.py
def render_html_report(model: Mapping[str, Any]) -> str:
    """Embed the model as a tag-safe JSON island in the static template.

    Accepts a single-patient model (``{"patient", "trials", ...}``) or a unified
    one (``{"patients": [...]}``). A single model is wrapped so the template
    always reads ``DATA.patients`` and a one-patient report skips the front page.
    """
    if "patients" not in model:
        model = {"patients": [dict(model)], "generated_at": model.get("generated_at", "")}
    data = json.dumps(model, ensure_ascii=False, default=str)
    # neutralize </script> injection — LLM free text may contain "</script>";
    # JSON.parse reads the escaped sequences back unchanged.
    data = data.replace("<", "\\u003c").replace(">", "\\u003e").replace("&", "\\u0026")
    html = _TEMPLATE.read_text(encoding="utf-8").replace(_DATA_PLACEHOLDER, data)
    return html.replace(_LOGO_PLACEHOLDER, _logo_data_uri())
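
The escaping step can be checked in isolation. The sketch below repeats the three replacements from the listing in a hypothetical helper, `to_json_island`, and uses Python's `json.loads` as a stand-in for the browser's `JSON.parse` to confirm that the data survives the round trip:

```python
import json
from typing import Any, Mapping


def to_json_island(model: Mapping[str, Any]) -> str:
    """Serialize to JSON, then escape the characters that matter inside a script tag."""
    data = json.dumps(model, ensure_ascii=False, default=str)
    return data.replace("<", "\\u003c").replace(">", "\\u003e").replace("&", "\\u0026")


model = {"note": "free text with </script><b>markup</b> & an ampersand"}
island = to_json_island(model)

# No raw tag characters remain, so the island cannot close its script element early.
print("<" in island, ">" in island, "&" in island)  # False False False

# The \uXXXX escapes are ordinary JSON string escapes, so parsing restores the text.
print(json.loads(island) == model)  # True
```

Escaping after serialization, rather than sanitizing the model fields, keeps the stored text untouched and leaves rendering decisions to the client-side DOM code.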