# MedCodeMCP / src/parse_tabular.py
# (Hugging Face file-viewer header, kept as a comment so the file parses)
# Author: gpaasch — commit 0ef172c:
#   "getting local to work to prevent burning up my credits"
# File size: 3.38 kB
import xml.etree.ElementTree as ET
import json
import sys
import os
from llama_index.core import VectorStoreIndex, Document, Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
# Path constants for locating the raw ICD-10-CM tabular XML and the
# processed JSON output. BASE_DIR resolves to the project root (the
# parent of the directory containing this file, i.e. src/..).
BASE_DIR = os.path.dirname(os.path.dirname(__file__))
# data/ holds both the raw CMS release and our processed artifacts.
DATA_DIR = os.path.join(BASE_DIR, "data")
# Directory of the unpacked 2025 ICD-10-CM tabular release.
ICD_DIR = os.path.join(DATA_DIR, "icd10cm_tabular_2025")
# The tabular XML file parsed by main() when no CLI path is given.
DEFAULT_XML_PATH = os.path.join(ICD_DIR, "icd10cm_tabular_2025.xml")
# Output directory for the flat code->description JSON.
PROCESSED_DIR = os.path.join(DATA_DIR, "processed")
def main(xml_path=DEFAULT_XML_PATH):
    """Parse the ICD-10-CM tabular XML and write a flat code->description JSON.

    Walks every ``<diag>`` element (recursively — sub-codes nest), extracts
    the ``<name>`` (ICD-10 code) and ``<desc>`` (human-readable description),
    and dumps the mapping to ``PROCESSED_DIR/icd_to_description.json``.

    Args:
        xml_path: Path to the tabular XML file; defaults to DEFAULT_XML_PATH.

    Returns:
        dict[str, str]: The code -> description mapping that was written.
        (New, backward-compatible: previous callers ignored the None return.)

    Exits:
        With status 1 if ``xml_path`` does not exist (script-style error).
    """
    # Create the output directory if it doesn't exist.
    os.makedirs(PROCESSED_DIR, exist_ok=True)

    if not os.path.isfile(xml_path):
        print(f"ERROR: cannot find tabular XML at '{xml_path}'")
        sys.exit(1)

    tree = ET.parse(xml_path)
    root = tree.getroot()

    icd_to_description = {}
    for diag in root.iter("diag"):
        # findtext() returns None for a missing child and "" for an empty
        # one, so a single truthiness check after strip() covers missing,
        # empty, and whitespace-only <name>/<desc> nodes in one step.
        code = (diag.findtext("name") or "").strip()
        description = (diag.findtext("desc") or "").strip()
        if code and description:
            icd_to_description[code] = description

    # Write out the flat JSON mapping code -> description.
    out_path = os.path.join(PROCESSED_DIR, "icd_to_description.json")
    with open(out_path, "w", encoding="utf-8") as fp:
        json.dump(icd_to_description, fp, indent=2, ensure_ascii=False)

    print(f"Wrote {len(icd_to_description)} code entries to {out_path}")
    return icd_to_description
def create_symptom_index():
    """Build a vector index over the processed ICD-10 code descriptions.

    Loads the code->description JSON produced by main(), wraps each pair in
    a Document (code kept in metadata for retrieval), and indexes them with
    a local HuggingFace embedding model — no remote embedding API is called.

    Returns:
        VectorStoreIndex: one Document per ICD-10 code.

    Raises:
        FileNotFoundError: if main() has not produced the JSON file yet.
    """
    # Local sentence-transformers model so no paid embedding API is hit.
    Settings.embed_model = HuggingFaceEmbedding(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    json_path = os.path.join(PROCESSED_DIR, "icd_to_description.json")
    # Fix: read with the same encoding main() wrote with. The file is
    # written as UTF-8 with ensure_ascii=False, so relying on the platform
    # default encoding breaks non-ASCII descriptions (e.g. on Windows).
    with open(json_path, "r", encoding="utf-8") as f:
        icd_data = json.load(f)

    # One Document per code; the code travels in metadata so query results
    # can be traced back to the exact ICD-10 entry.
    documents = [
        Document(
            text=f"ICD-10 Code {code}: {desc}",
            metadata={"code": code},
        )
        for code, desc in icd_data.items()
    ]

    return VectorStoreIndex.from_documents(documents)
# Module-level handle so other modules can import `symptom_index` from this
# module; it stays None until the script body below populates it.
symptom_index = None

if __name__ == "__main__":
    # Optional CLI argument: path to the tabular XML; otherwise the default.
    if len(sys.argv) > 1:
        main(sys.argv[1])
    else:
        main()  # Use default path

    symptom_index = create_symptom_index()

    # Smoke-test retrieval with a few representative symptom phrases.
    test_queries = [
        "persistent cough with fever",
        "severe headache with nausea",
        "lower back pain",
        "difficulty breathing",
    ]

    print("\nTesting symptom matching:")
    print("-" * 50)

    # Build the query engine once — it is loop-invariant, and constructing
    # it per query repeated the same setup work for every iteration.
    query_engine = symptom_index.as_query_engine()
    for query in test_queries:
        response = query_engine.query(query)
        print(f"\nQuery: {query}")
        print("Relevant ICD-10 codes:")  # was an f-string with no placeholders
        print(str(response))
        print("-" * 50)