# Source code for scripts.calibration.tb_calibration_south_africa

"""
TB Model Calibration for South Africa Data

This script generates model outputs designed to match real South Africa TB data:
1. Case notification data (2000, 2005, 2010, 2015, 2020)
2. Age-stratified active TB prevalence from 2018 survey

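The script produces the model outputs together with the calibration data they are compared against. The data tables are synthetic approximations assembled from WHO Global TB Reports and the 2018 South Africa TB Prevalence Survey, not raw survey records.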
"""

import starsim as ss
import tbsim as mtb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import datetime
import time
import sys
import os
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Make the shared HIV helpers (scripts/hiv, home of shared_functions), this
# script's own directory, and its parent directory importable. Each directory
# is prepended only when it is not already on sys.path, in the order listed,
# so the last one inserted ends up at the very front.
try:
    current_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # __file__ is undefined in some interactive sessions; use the CWD instead.
    current_dir = os.getcwd()
hiv_utils_path = os.path.abspath(os.path.join(current_dir, '../../scripts/hiv'))
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
for _extra_path in (hiv_utils_path, current_dir, parent_dir):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)
del _extra_path
import shared_functions as sf

# Import health-seeking, diagnostic, and treatment interventions
from tbsim.interventions.tb_health_seeking import HealthSeekingBehavior
from tbsim.interventions.tb_diagnostic import TBDiagnostic
from tbsim.interventions.tb_treatment import TBTreatment


def create_south_africa_data():
    """
    Build the synthetic South Africa TB calibration dataset.

    Returns:
        dict: With keys
            - ``'case_notifications'``: DataFrame of notification rates and
              estimated case counts for 2000-2020 (five-year steps),
            - ``'age_prevalence'``: DataFrame of 2018 survey active-TB
              prevalence by age group,
            - ``'targets'``: dict of scalar calibration targets.
    """
    # Case notification rates per 100,000 population, based on WHO Global TB
    # Reports. The series rises to a peak around 2010 and declines afterwards.
    notification_years = [2000, 2005, 2010, 2015, 2020]
    case_notifications = pd.DataFrame({
        'year': notification_years,
        'rate_per_100k': [650, 950, 980, 834, 554],
        'total_cases': [280000, 450000, 490000, 450000, 320000],  # estimated totals
        'source': ['WHO Global TB Report'] * len(notification_years),
    })

    # Active TB prevalence per 100,000 by age group, based on the South Africa
    # TB Prevalence Survey 2018 (higher in older age groups).
    age_labels = ['15-24', '25-34', '35-44', '45-54', '55-64', '65+']
    per_100k = [850, 1200, 1400, 1600, 1800, 2200]
    age_prevalence = pd.DataFrame({
        'age_group': age_labels,
        'prevalence_per_100k': per_100k,
        'prevalence_percent': [value / 1000 for value in per_100k],  # per 100k -> %
        'sample_size': [5000, 4500, 4000, 3500, 3000, 2500],  # survey sample sizes
        'source': ['SA TB Prevalence Survey 2018'] * len(age_labels),
    })

    # Scalar targets; the overall prevalence is expressed in percent.
    targets = {
        'overall_prevalence_2018': 0.852,  # 0.852% from survey
        'hiv_coinfection_rate': 0.60,      # share of TB cases that are HIV-positive
        'case_detection_rate': 0.65,       # share of cases detected
        'treatment_success_rate': 0.78,
        'mortality_rate': 0.12,            # case fatality rate
    }

    return {
        'case_notifications': case_notifications,
        'age_prevalence': age_prevalence,
        'targets': targets,
    }
def compute_age_stratified_prevalence(sim, target_year=2018, active_states=None):
    """
    Compute age-stratified active TB prevalence from simulation results.

    NOTE: per-agent ages, alive flags and TB states are read from
    ``sim.people`` and ``sim.diseases.tb`` as they stand after the run, i.e.
    at the final simulation time. ``target_year`` is accepted for backward
    compatibility but is not used to select a time point.

    Args:
        sim: Simulation object (already run)
        target_year: Unused; kept for API compatibility
        active_states: TB state values counted as active disease. Defaults to
            smear-positive, smear-negative and extrapulmonary active TB.

    Returns:
        dict: Maps each age-group label ('15-24' ... '65+') to a dict with
        'prevalence' (fraction), 'prevalence_per_100k', 'total_people' and
        'tb_cases', counted over alive agents only.
    """
    if active_states is None:
        active_states = [mtb.TBS.ACTIVE_SMPOS, mtb.TBS.ACTIVE_SMNEG, mtb.TBS.ACTIVE_EXPTB]

    people = sim.people
    alive_mask = people.alive
    active_tb_mask = np.isin(sim.diseases.tb.state, active_states)

    ages = people.age[alive_mask]
    active_tb_ages = people.age[alive_mask & active_tb_mask]

    # Half-open bins [lower, upper). Ages are continuous, so closed integer
    # ranges such as [15, 24] / [25, 34] would silently drop e.g. age 24.5.
    age_bins = [
        ('15-24', 15, 25),
        ('25-34', 25, 35),
        ('35-44', 35, 45),
        ('45-54', 45, 55),
        ('55-64', 55, 65),
        ('65+', 65, np.inf),
    ]

    prevalence_by_age = {}
    for label, lower, upper in age_bins:
        total_in_age_group = np.sum((ages >= lower) & (ages < upper))
        tb_in_age_group = np.sum((active_tb_ages >= lower) & (active_tb_ages < upper))

        # Empty age groups report zero prevalence rather than dividing by zero.
        prevalence = tb_in_age_group / total_in_age_group if total_in_age_group > 0 else 0

        prevalence_by_age[label] = {
            'prevalence': prevalence,
            'prevalence_per_100k': prevalence * 100000,
            'total_people': total_in_age_group,
            'tb_cases': tb_in_age_group,
        }

    return prevalence_by_age
def compute_case_notifications(sim, target_years=(2000, 2005, 2010, 2015, 2020)):
    """
    Compute annual case notifications from simulation results.

    The diagnosed-case count for a year is the sum of the per-timestep
    ``tbdiagnostic`` ``n_test_positive`` results over every timestep that falls
    in that calendar year. The population denominator is the mean ``n_alive``
    over the same timesteps. If the simulation has no timestep in a requested
    year, the nearest simulated year is used instead.

    Args:
        sim: Simulation object (already run)
        target_years: Iterable of calendar years to compute notifications for

    Returns:
        dict: Maps each target year to a dict with 'diagnosed_cases',
        'rate_per_100k' (annual rate per 100,000) and 'population'.
    """
    def _as_array(result):
        # Result objects expose their data via ``.values``; plain arrays are
        # accepted as-is.
        return np.asarray(getattr(result, 'values', result), dtype=float)

    time_years = np.array([d.year for d in sim.results['timevec']])
    n_test_positive = _as_array(sim.results['tbdiagnostic']['n_test_positive'])
    n_alive = _as_array(sim.results['n_alive'])

    notifications_by_year = {}
    for target_year in target_years:
        in_year = time_years == target_year
        if not in_year.any():
            # Fall back to the closest year the simulation actually covers.
            nearest_year = time_years[np.argmin(np.abs(time_years - target_year))]
            in_year = time_years == nearest_year

        # Notifications are an annual quantity: sum every step in the year
        # instead of reading a single (e.g. 30-day) step.
        n_diagnosed = n_test_positive[in_year].sum()
        population = n_alive[in_year].mean()
        rate_per_100k = (n_diagnosed / population) * 100000 if population > 0 else 0

        notifications_by_year[target_year] = {
            'diagnosed_cases': n_diagnosed,
            'rate_per_100k': rate_per_100k,
            'population': population,
        }

    return notifications_by_year
def plot_calibration_comparison(sim, sa_data, timestamp):
    """
    Create a 2x2 figure comparing the simulation against the South Africa data.

    Panels: (1) case notification rates by year, (2) age-stratified active TB
    prevalence, (3) overall active TB prevalence over time against the 2018
    target, and (4) the cumulative diagnosis/treatment cascade. Model values
    are looked up for exactly the years and age-group labels present in
    ``sa_data``, so model and data points are paired by key, not by position.

    The figure is saved as ``tb_calibration_comparison_{timestamp}.pdf`` in the
    current working directory and then displayed with ``plt.show()``.

    Args:
        sim: Simulation object (already run)
        sa_data: South Africa data dictionary (see ``create_south_africa_data``)
        timestamp: Timestamp for file naming

    Returns:
        matplotlib.figure.Figure: The generated figure.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

    _plot_notification_panel(ax1, sim, sa_data)
    _plot_age_prevalence_panel(ax2, sim, sa_data)
    _plot_prevalence_timeseries_panel(ax3, sim, sa_data)
    _plot_care_cascade_panel(ax4, sim)

    plt.tight_layout()
    plt.suptitle('TB Model Calibration: South Africa Data Comparison', fontsize=16, y=1.02)
    filename = f"tb_calibration_comparison_{timestamp}.pdf"
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    plt.show()
    return fig


def _plot_notification_panel(ax, sim, sa_data):
    """Panel 1: model vs data notification rates, annotated with % difference."""
    data = sa_data['case_notifications']
    years = [int(year) for year in data['year']]
    data_rates = data['rate_per_100k'].values
    # Ask the model for exactly the data's years so both series line up.
    notifications = compute_case_notifications(sim, target_years=years)
    model_rates = [notifications[year]['rate_per_100k'] for year in years]

    ax.plot(years, model_rates, 'bo-', label='Model Output', linewidth=2, markersize=8)
    ax.plot(years, data_rates, 'ro-', label='South Africa Data', linewidth=2, markersize=8)
    ax.set_xlabel('Year')
    ax.set_ylabel('Case Notification Rate (per 100,000)')
    ax.set_title('TB Case Notifications: Model vs Data')
    ax.legend()
    ax.grid(True, alpha=0.3)

    for year, model_rate, data_rate in zip(years, model_rates, data_rates):
        if data_rate > 0:
            pct_diff = ((model_rate - data_rate) / data_rate) * 100
            ax.annotate(f'{pct_diff:.1f}%', xy=(year, model_rate), xytext=(0, 10),
                        textcoords='offset points', ha='center', fontsize=8)


def _plot_age_prevalence_panel(ax, sim, sa_data):
    """Panel 2: grouped bars of model vs data prevalence per age group."""
    data = sa_data['age_prevalence']
    age_groups = list(data['age_group'])
    data_prevalence = data['prevalence_per_100k'].values
    # Index the model output by the data's labels so bars pair by age group.
    age_prevalence = compute_age_stratified_prevalence(sim)
    model_prevalence = [age_prevalence[group]['prevalence_per_100k'] for group in age_groups]

    x_pos = np.arange(len(age_groups))
    width = 0.35
    ax.bar(x_pos - width / 2, model_prevalence, width, label='Model Output', alpha=0.8)
    ax.bar(x_pos + width / 2, data_prevalence, width, label='South Africa Data', alpha=0.8)
    ax.set_xlabel('Age Group')
    ax.set_ylabel('Active TB Prevalence (per 100,000)')
    ax.set_title('Age-Stratified TB Prevalence: Model vs Data (2018)')
    ax.set_xticks(x_pos)
    ax.set_xticklabels(age_groups)
    ax.legend()
    ax.grid(True, alpha=0.3)

    for i, (model_val, data_val) in enumerate(zip(model_prevalence, data_prevalence)):
        if data_val > 0:
            pct_diff = ((model_val - data_val) / data_val) * 100
            ax.annotate(f'{pct_diff:.1f}%', xy=(i, max(model_val, data_val)), xytext=(0, 5),
                        textcoords='offset points', ha='center', fontsize=8)


def _plot_prevalence_timeseries_panel(ax, sim, sa_data):
    """Panel 3: model active-TB prevalence (%) over time with the 2018 target line."""
    time_years = np.array([d.year for d in sim.results['timevec']])
    active_prev = sim.results['tb']['prevalence_active']
    target = sa_data['targets']['overall_prevalence_2018']

    ax.plot(time_years, active_prev * 100, 'b-', linewidth=2, label='Model Active TB Prevalence')
    ax.axhline(target, color='r', linestyle='--', label=f"Target: {target:.3f}%")
    ax.set_xlabel('Year')
    ax.set_ylabel('Active TB Prevalence (%)')
    ax.set_title('Overall TB Prevalence Over Time')
    ax.legend()
    ax.grid(True, alpha=0.3)


def _plot_care_cascade_panel(ax, sim):
    """Panel 4: bars of cumulative diagnoses, treatment successes and failures at the end."""
    tbdiag = sim.results['tbdiagnostic']
    tbtx = sim.results['tbtreatment']

    cascade_data = [
        tbdiag['cum_test_positive'].values[-1],
        tbtx['cum_treatment_success'].values[-1],
        tbtx['cum_treatment_failure'].values[-1],
    ]
    cascade_labels = ['Diagnosed', 'Successfully Treated', 'Treatment Failures']
    colors = ['skyblue', 'lightgreen', 'lightcoral']

    bars = ax.bar(cascade_labels, cascade_data, color=colors, alpha=0.8)
    ax.set_ylabel('Number of People')
    ax.set_title('TB Care Cascade (Cumulative)')
    ax.grid(True, alpha=0.3)

    # Value labels sit just above each bar (1% of the tallest bar).
    label_offset = max(cascade_data) * 0.01
    for bar, value in zip(bars, cascade_data):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + label_offset,
                f'{int(value):,}', ha='center', va='bottom', fontweight='bold')
def _rmse(model, data):
    """Root-mean-square error between two equal-length numeric arrays."""
    return np.sqrt(np.mean((model - data) ** 2))


def _mape(model, data):
    """Mean absolute percentage error (in %) of ``model`` relative to ``data``."""
    return np.mean(np.abs((model - data) / data)) * 100


def _compute_fit_metrics(sim, sa_data, target_year=2018):
    """
    Compute model-vs-data fit metrics shared by the score and the report.

    Model values are looked up for exactly the years and age-group labels in
    ``sa_data`` so each model value is paired with its data point by key.

    Returns:
        dict: Aligned model/data arrays plus RMSE/MAPE for notifications and
        age prevalence, and the overall prevalence (in %) at ``target_year``
        together with its absolute error in percentage points.
    """
    notif_data = sa_data['case_notifications']
    years = [int(year) for year in notif_data['year']]
    data_rates = notif_data['rate_per_100k'].values
    notifications = compute_case_notifications(sim, target_years=years)
    model_rates = np.array([notifications[year]['rate_per_100k'] for year in years])

    age_data = sa_data['age_prevalence']
    age_groups = list(age_data['age_group'])
    data_age_prev = age_data['prevalence_per_100k'].values
    age_prevalence = compute_age_stratified_prevalence(sim, target_year=target_year)
    model_age_prev = np.array([age_prevalence[group]['prevalence_per_100k'] for group in age_groups])

    # Overall prevalence: the simulated timestep whose year is closest to the target.
    time_years = np.array([d.year for d in sim.results['timevec']])
    active_prev = sim.results['tb']['prevalence_active']
    target_idx = np.argmin(np.abs(time_years - target_year))
    model_overall_prev = active_prev[target_idx] * 100  # fraction -> percent
    target_overall_prev = sa_data['targets']['overall_prevalence_2018']

    return {
        'years': years,
        'model_rates': model_rates,
        'data_rates': data_rates,
        'notification_rmse': _rmse(model_rates, data_rates),
        'notification_mape': _mape(model_rates, data_rates),
        'age_groups': age_groups,
        'model_age_prev': model_age_prev,
        'data_age_prev': data_age_prev,
        'age_prev_rmse': _rmse(model_age_prev, data_age_prev),
        'age_prev_mape': _mape(model_age_prev, data_age_prev),
        'model_overall_prev': model_overall_prev,
        'target_overall_prev': target_overall_prev,
        'overall_prev_error': abs(model_overall_prev - target_overall_prev),
    }


def calculate_calibration_score(sim, sa_data):
    """
    Calculate a composite calibration score based on multiple metrics.

    Args:
        sim: Simulation object (already run)
        sa_data: South Africa data dictionary

    Returns:
        dict: RMSE/MAPE for case notifications and age prevalence, the overall
        2018 prevalence (model, target, absolute error in percentage points)
        and 'composite_score' (lower is better).
    """
    metrics = _compute_fit_metrics(sim, sa_data)

    # Weights: notifications 40%, age prevalence 40%, overall prevalence 20%.
    # The overall error (percentage points) is scaled by 100, so a 0.01 pp
    # error contributes like a 1% MAPE before weighting.
    composite_score = (
        0.4 * metrics['notification_mape']
        + 0.4 * metrics['age_prev_mape']
        + 0.2 * (metrics['overall_prev_error'] * 100)
    )

    return {
        'notification_rmse': metrics['notification_rmse'],
        'notification_mape': metrics['notification_mape'],
        'age_prev_rmse': metrics['age_prev_rmse'],
        'age_prev_mape': metrics['age_prev_mape'],
        'overall_prev_error': metrics['overall_prev_error'],
        'model_overall_prev': metrics['model_overall_prev'],
        'target_overall_prev': metrics['target_overall_prev'],
        'composite_score': composite_score,
    }


def create_calibration_report(sim, sa_data, timestamp):
    """
    Create a detailed calibration report, save it as JSON and print a summary.

    The report is written to ``calibration_report_{timestamp}.json`` in the
    current working directory; values that are not JSON-serializable (e.g.
    starsim rate objects) are stored via ``str()``.

    Args:
        sim: Simulation object (already run)
        sa_data: South Africa data dictionary
        timestamp: Timestamp for file naming

    Returns:
        dict: The report (fit metrics and key model parameters).
    """
    import json

    metrics = _compute_fit_metrics(sim, sa_data)
    tb_pars = sim.diseases.tb.pars

    report = {
        'timestamp': timestamp,
        'case_notifications': {
            'model_rates': metrics['model_rates'].tolist(),
            'data_rates': metrics['data_rates'].tolist(),
            'rmse': metrics['notification_rmse'],
            'mape': metrics['notification_mape'],
            'years': metrics['years'],
        },
        'age_prevalence': {
            'model_rates': metrics['model_age_prev'].tolist(),
            'data_rates': metrics['data_age_prev'].tolist(),
            'rmse': metrics['age_prev_rmse'],
            'mape': metrics['age_prev_mape'],
            'age_groups': metrics['age_groups'],
        },
        'overall_prevalence': {
            'model_2018': metrics['model_overall_prev'],
            'target_2018': metrics['target_overall_prev'],
            'error': metrics['overall_prev_error'],
        },
        'model_parameters': {
            'beta': tb_pars.beta,
            'rel_sus_latentslow': tb_pars.rel_sus_latentslow,
            'tb_mortality': tb_pars.rate_smpos_to_dead,
            # Final-timestep HIV prevalence, or 0 when no HIV results exist.
            'hiv_prevalence': sim.results['hiv']['hiv_prevalence'][-1] if 'hiv' in sim.results else 0,
        },
    }

    filename = f"calibration_report_{timestamp}.json"
    with open(filename, 'w') as f:
        json.dump(report, f, indent=2, default=str)

    params = report['model_parameters']
    print("\n=== TB Model Calibration Report ===")
    print(f"Timestamp: {timestamp}")
    print("\nCase Notification Fit:")
    print(f"  RMSE: {metrics['notification_rmse']:.1f} per 100,000")
    print(f"  MAPE: {metrics['notification_mape']:.1f}%")
    print("\nAge Prevalence Fit:")
    print(f"  RMSE: {metrics['age_prev_rmse']:.1f} per 100,000")
    print(f"  MAPE: {metrics['age_prev_mape']:.1f}%")
    print("\nOverall Prevalence (2018):")
    print(f"  Model: {metrics['model_overall_prev']:.3f}%")
    print(f"  Target: {metrics['target_overall_prev']:.3f}%")
    print(f"  Error: {metrics['overall_prev_error']:.3f} percentage points")
    print("\nModel Parameters:")
    print(f"  Beta: {params['beta']}")
    print(f"  Rel Sus Latent: {params['rel_sus_latentslow']}")
    print(f"  TB Mortality: {params['tb_mortality']}")
    print(f"  HIV Prevalence: {params['hiv_prevalence']:.3f}")
    print("================================")

    return report
def _find_data_file(filename):
    """
    Return the first existing candidate path for a demographic data file.

    Candidates are tried in order: ``../data``, ``tbsim/data`` and ``data``
    (all relative to the current working directory), then a ``data`` folder
    next to the imported ``tbsim`` package.

    Raises:
        FileNotFoundError: If no candidate exists; the message lists them all.
    """
    candidates = [
        f'../data/{filename}',
        f'tbsim/data/{filename}',
        f'data/{filename}',
    ]
    # Fallback so the script also works from other working directories.
    # Assumes the CSVs are shipped in tbsim/data -- verify for installed builds.
    package_file = getattr(mtb, '__file__', None)
    if package_file:
        candidates.append(os.path.join(os.path.dirname(package_file), 'data', filename))

    for path in candidates:
        if os.path.exists(path):
            return path
    raise FileNotFoundError(
        f"Could not find {filename} in any of the expected locations: {candidates}"
    )


def run_calibration_simulation(beta=0.020, rel_sus_latentslow=0.15, tb_mortality=3e-4,
                               seed=0, years=200, n_agents=1000):
    """
    Run a single calibration simulation with specified parameters.

    The simulation starts on 1850-01-01, runs for ``years`` years with a
    30-day timestep, and includes South Africa births (CBR) and deaths (ASMR),
    TB and HIV disease modules, a random contact network, a TB-HIV connector,
    and HIV, health-seeking, diagnostic and treatment interventions.

    Args:
        beta: TB transmission rate (wrapped in ``ss.per``)
        rel_sus_latentslow: Relative susceptibility passed to the TB module as
            ``rel_sus_latentslow``
        tb_mortality: Per-day death rate for smear-positive active TB; the
            smear-negative and extrapulmonary rates are 0.3x and 0.15x of it
        seed: Random seed
        years: Simulation duration in years
        n_agents: Number of agents

    Returns:
        sim: The simulation object after ``sim.run()``

    Raises:
        FileNotFoundError: If either demographic CSV cannot be located.
    """
    start_year = 1850
    sim_pars = dict(
        dt=ss.days(30),
        start=ss.date(f'{start_year}-01-01'),
        stop=ss.date(f'{start_year + years}-01-01'),
        rand_seed=seed,
        verbose=0,
    )

    # Demographic data (crude birth rate, age-specific mortality rate).
    cbr = pd.read_csv(_find_data_file('South_Africa_CBR.csv'))
    asmr = pd.read_csv(_find_data_file('South_Africa_ASMR.csv'))
    demog = [
        ss.Births(birth_rate=cbr, dt=ss.days(30)),
        ss.Deaths(death_rate=asmr, dt=ss.days(30), rate_units=1),
    ]

    people = ss.People(n_agents=n_agents, extra_states=mtb.get_extrastates())

    # TB natural history; death rates for smear-negative and extrapulmonary
    # TB are fixed fractions of the smear-positive rate.
    tb_pars = dict(
        beta=ss.per(beta),
        init_prev=ss.bernoulli(p=0.10),
        rel_sus_latentslow=rel_sus_latentslow,
        rate_LS_to_presym=ss.perday(5e-5),
        rate_LF_to_presym=ss.perday(8e-3),
        rate_active_to_clear=ss.perday(1.5e-4),
        rate_smpos_to_dead=ss.perday(tb_mortality),
        rate_exptb_to_dead=ss.perday(0.15 * tb_mortality),
        rate_smneg_to_dead=ss.perday(0.3 * tb_mortality),
    )
    tb = sf.make_tb(tb_pars=tb_pars)

    # HIV starts absent; prevalence is driven by the HIV intervention below.
    hiv_pars = dict(
        init_prev=ss.bernoulli(p=0.00),
        init_onart=ss.bernoulli(p=0.00),
    )
    hiv = sf.make_hiv(hiv_pars=hiv_pars)

    net = ss.RandomNet(pars=dict(n_contacts=ss.poisson(lam=5), dur=0))
    tb_hiv_connector = sf.make_tb_hiv_connector()

    hiv_intervention = sf.make_hiv_interventions(pars=dict(
        mode='both',
        prevalence=0.20,
        percent_on_ART=0.50,
        min_age=15,
        max_age=60,
        start=ss.date(f'{start_year}-01-01'),
        stop=ss.date(f'{start_year + years}-01-01'),
    ))

    health_seeking = HealthSeekingBehavior(pars=dict(
        initial_care_seeking_rate=ss.perday(1 / 90),
        start=ss.date(f'{start_year}-01-01'),
        stop=ss.date(f'{start_year + years}-01-01'),
        single_use=True,
    ))

    tb_diagnostic = TBDiagnostic(pars=dict(
        coverage=ss.bernoulli(0.7, strict=False),
        sensitivity=0.60,
        specificity=0.95,
        reset_flag=False,
        care_seeking_multiplier=2.0,
    ))

    tb_treatment = TBTreatment(pars=dict(
        treatment_success_rate=0.70,
        reseek_multiplier=2.0,
        reset_flags=True,
    ))

    # make_hiv_interventions returns a list, so the TB interventions are appended.
    all_interventions = hiv_intervention + [health_seeking, tb_diagnostic, tb_treatment]

    sim = ss.Sim(
        people=people,
        diseases=[tb, hiv],
        networks=net,
        demographics=demog,
        connectors=tb_hiv_connector,
        interventions=all_interventions,
        pars=sim_pars,
    )
    sim.run()
    return sim
def main():
    """
    Run the full South Africa calibration workflow.

    Builds the calibration data, runs one simulation, then writes the
    comparison figure, the JSON report and CSV copies of the data tables,
    all tagged with a shared timestamp. Progress is printed at each step.

    Returns:
        tuple: ``(sim, sa_data, report)``.
    """
    print("Starting TB Model Calibration for South Africa...")

    sa_data = create_south_africa_data()
    print("✓ Created South Africa calibration data")

    print("Running calibration simulation...")
    sim_params = dict(beta=0.020, rel_sus_latentslow=0.15, tb_mortality=3e-4, n_agents=1000)
    sim = run_calibration_simulation(**sim_params)
    print("✓ Simulation completed")

    stamp = datetime.datetime.now().strftime("%Y_%m_%d_%H%M")

    print("Creating calibration plots...")
    plot_calibration_comparison(sim, sa_data, stamp)
    print("✓ Calibration plots created")

    print("Creating calibration report...")
    report = create_calibration_report(sim, sa_data, stamp)
    print("✓ Calibration report created")

    # Keep a copy of the calibration data next to the model outputs.
    tables = (
        ('case_notifications', 'sa_case_notifications'),
        ('age_prevalence', 'sa_age_prevalence'),
    )
    for table_key, prefix in tables:
        sa_data[table_key].to_csv(f"{prefix}_{stamp}.csv", index=False)
    print("✓ South Africa data saved")

    print("\nCalibration analysis completed!")
    print(f"Files created with timestamp: {stamp}")
    return sim, sa_data, report
if __name__ == '__main__':
    # Run the workflow only when executed as a script, not when imported.
    sim, sa_data, report = main()