Source code for scripts.burn_in.archive.run_tb_burn_in_South_Africa

"""
Refined TB Prevalence Sweep for Manual Calibration (South Africa) with Health-Seeking, Diagnostic, and Treatment

This script performs a manual calibration sweep of TB transmission dynamics in Starsim/TBsim
to explore plausible endemic equilibrium behavior for South Africa, incorporating key 
epidemiological features specific to the South African context, including health-seeking behavior,
diagnostic testing, and treatment outcomes.

🎯 Objective:
    - Calibrate burn-in dynamics (i.e., rise and settle to endemic steady state)
    - Target approximately:
        • >50% latent TB prevalence
        • ~1% active TB prevalence
    - Qualitative fit to empirical data point: 0.852% active TB prevalence (South Africa, 2018)
    - Model realistic health-seeking behavior, diagnostic testing, and treatment outcomes

🔧 Current Assumptions:
    - Includes HIV coinfection (critical for South Africa TB dynamics)
    - Models TB-HIV interaction effects on progression rates
    - Uses South Africa-specific demographics and population structure
    - Simulation starts in 1850 and runs 200 years to allow for equilibrium
    - Incorporates historical HIV epidemic emergence (1980s onwards)
    - Health-seeking behavior with 90-day average delay (slower for better burn-in)
    - Diagnostic testing with 60% sensitivity, 70% coverage, and 95% specificity
    - Treatment with 70% success rate and retry mechanism for failures

📊 What It Does:
    - Sweeps across a grid of:
        • TB infectiousness (β)
        • Reinfection susceptibility (rel_sus_latentslow)
        • TB mortality rates
    - For each parameter combo, it:
        • Runs a simulation with TB-HIV coinfection + health-seeking + diagnostic
        • Plots active and latent prevalence over time
        • Plots health-seeking behavior metrics
        • Plots diagnostic testing outcomes
        • Plots treatment outcomes (incident and cumulative)
        • Overlays the 2018 SA data point on each plot
        • Adds an inset focused on the post-1980 period (zoomed to 0–1% active prevalence)
    - Outputs multiple PDF figures with all subplots, timestamped with run time
    - Prints runtime diagnostics including total sweep duration

📥 Inputs:
    - Hardcoded ranges for β, rel_sus_latentslow, and TB mortality
    - South Africa-specific demographic parameters
    - HIV epidemic parameters (prevalence targets, timing)
    - Health-seeking parameters (90-day average delay)
    - Diagnostic parameters (60% sensitivity, 70% coverage, 95% specificity)

📤 Outputs:
    - PDF files showing:
        • TB prevalence trajectories across parameter grid
        • Health-seeking behavior metrics
        • Diagnostic testing outcomes (incident and cumulative)
        • Treatment outcomes (incident and cumulative)
        • Population demographics
        • HIV metrics
    - Console logging of sweep progress and timing

⚠️ Notes:
    - Active prevalence <1% is sensitive to population size; low agent counts may cause extinction
    - HIV coinfection significantly impacts TB dynamics in South Africa
    - Health-seeking behavior affects TB transmission and case detection
    - Diagnostic testing influences treatment initiation and outcomes
    - Treatment success/failure affects TB transmission and care-seeking behavior
    - This model now better reflects the South African epidemiological context with realistic care-seeking and treatment

"""

import starsim as ss
import tbsim as mtb
from tbsim.comorbidities.hiv.hiv import HIVState
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import datetime
import os
import rdata
import time
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import common_functions as cf

import warnings
warnings.filterwarnings("ignore", message='Missing constructor for R class "data.table".*')

# Ensure the project root is in sys.path for package imports
import sys
import os
try:
    script_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # __file__ is not defined (e.g., in interactive mode)
    script_dir = os.getcwd()
sys.path.insert(0, os.path.abspath(os.path.join(script_dir, '..')))

# Helper function to get output path based on extension
def get_output_path(filename):
    """
    Build the output path for ``filename`` under the ``samples`` tree.

    The sub-directory is selected from the text after the last ``.`` in
    ``filename`` (lower-cased); anything unrecognised goes to ``../samples/``.
    The target directory is created if it does not exist yet.

    Args:
        filename: Name of the file that will be written.

    Returns:
        str: ``<script_dir>/<subdir>/<filename>``, where ``<script_dir>`` is the
        directory of this file, or the current working directory when
        ``__file__`` is undefined (e.g. in a notebook).
    """
    extension_dirs = {
        'pdf': '../samples/pdf/',
        'csv': '../samples/csv/',
        'png': '../samples/png/',
        'json': '../samples/json/',
        'md': '../samples/md/',
    }
    suffix = filename.rpartition('.')[2].lower()
    relative_dir = extension_dirs.get(suffix, '../samples/')

    try:
        base_dir = os.path.dirname(__file__)
    except NameError:
        # __file__ is not defined in interactive sessions
        base_dir = os.getcwd()

    target_dir = os.path.join(base_dir, relative_dir)
    os.makedirs(target_dir, exist_ok=True)
    return os.path.join(target_dir, filename)
class GradualHIVIntervention(ss.Intervention):
    """
    Custom HIV intervention that pushes HIV prevalence up to year-specific targets.

    Targets are based on van Schalkwyk et al. 2021 data for eThekwini,
    South Africa, and are kept separately for two age groups: 15-24 and 25+
    (applied in ``step`` to ages 25-60). Between target years the most recent
    target is held; the intervention only ever adds infections, it never
    removes them when prevalence is above target.
    """

    def __init__(self, pars, **kwargs):
        """
        Args:
            pars: Parameter overrides forwarded to ``update_pars``.
            **kwargs: Forwarded to ``ss.Intervention.__init__`` and ``update_pars``.
        """
        super().__init__(**kwargs)
        self.define_pars(
            percent_on_ART=0.50,  # 50% of HIV-positive individuals on ART
            start=ss.date('1990-01-01'),
            stop=ss.date('2050-12-31'),
        )
        self.update_pars(pars, **kwargs)

        # (year, prevalence) targets for adults 25+ (estimated + survey data)
        self.hiv_targets_25plus = [
            (1990, 0.01),  # 1% in 1990
            (1995, 0.04),  # 4% in 1995
            (2000, 0.11),  # 11% in 2000
            (2005, 0.19),  # 19% in 2005 (survey data)
            (2008, 0.22),  # 22% in 2008 (survey data)
            (2010, 0.18),  # 18% in 2010 (estimated)
            (2013, 0.21),  # 21% in 2013 (survey data)
            (2015, 0.19),  # 19% in 2015 (estimated)
            (2018, 0.25),  # 25% in 2018 (survey data)
        ]

        # (year, prevalence) targets for adults 15-24 (simplified non-decreasing trend)
        self.hiv_targets_15to24 = [
            (1990, 0.01),  # 1% in 1990
            (1995, 0.05),  # 5% in 1995
            (2000, 0.10),  # 10% in 2000
            (2005, 0.10),  # 10% in 2005 (leveled off)
            (2010, 0.10),  # 10% in 2010 (leveled off)
            (2015, 0.10),  # 10% in 2015 (leveled off)
        ]

    @staticmethod
    def _target_for_year(targets, current_year):
        """Return the prevalence of the last ``(year, prev)`` entry with ``year <= current_year`` (0.0 if none)."""
        target = 0.0
        for year, prev in targets:
            if current_year >= year:
                target = prev
        return target

    def step(self):
        """Apply the current year's prevalence targets to both age groups."""
        t = self.sim.now
        if t < self.pars.start or t > self.pars.stop:
            return

        current_year = t.year
        target_prevalence_25plus = self._target_for_year(self.hiv_targets_25plus, current_year)
        target_prevalence_15to24 = self._target_for_year(self.hiv_targets_15to24, current_year)

        self._apply_prevalence(target_prevalence_25plus, min_age=25, max_age=60)
        self._apply_prevalence(target_prevalence_15to24, min_age=15, max_age=24)

    def _apply_prevalence(self, target_prevalence, min_age=25, max_age=60):
        """
        Infect at-risk people so the age band reaches ``target_prevalence``.

        Args:
            target_prevalence: Desired fraction of HIV-positive people in the band.
            min_age: Lower bound of the band (inclusive).
            max_age: Upper bound of the band in completed years (inclusive), i.e.
                the band is ``min_age <= age < max_age + 1``. Ages are continuous,
                so a plain ``age <= max_age`` would leave e.g. 24.5-year-olds in
                neither the 15-24 nor the 25+ band.
        """
        self.hiv = self.sim.diseases.hiv
        people = self.sim.people

        # Alive people inside the age band
        alive_mask = people.alive
        age_mask = (people.age >= min_age) & (people.age < max_age + 1)
        eligible_mask = alive_mask & age_mask
        eligible_uids = people.auids[eligible_mask]
        if len(eligible_uids) == 0:
            return

        target_infectious = int(np.round(len(eligible_uids) * target_prevalence))

        # Count those in the band who are already HIV-positive
        eligible_hiv_states = self.hiv.state[eligible_uids]
        hiv_positive_mask = np.isin(
            eligible_hiv_states,
            [HIVState.ACUTE, HIVState.LATENT, HIVState.AIDS],
        )
        current_infectious_uids = eligible_uids[hiv_positive_mask]
        delta = target_infectious - len(current_infectious_uids)
        if delta <= 0:
            # At or above target: nothing to add (infections are never removed)
            return

        at_risk_uids = eligible_uids[eligible_hiv_states == HIVState.ATRISK]
        n_new = min(delta, len(at_risk_uids))  # cannot infect more people than are at risk
        if n_new <= 0:
            return

        # Randomly select people to infect
        chosen_indices = np.random.choice(len(at_risk_uids), size=n_new, replace=False)
        chosen_uids = at_risk_uids[chosen_indices]
        self.hiv.state[chosen_uids] = HIVState.ACUTE

        # Put a fixed share of the newly infected on ART
        n_art = int(len(chosen_uids) * self.pars.percent_on_ART)
        art_indices = np.random.choice(len(chosen_uids), size=n_art, replace=False)
        art_uids = chosen_uids[art_indices]
        self.hiv.on_ART[art_uids] = True
from mpl_toolkits.axes_grid1.inset_locator import inset_axes import datetime import time import sys import os # Dynamically add the correct path to scripts for common_functions import try: current_dir = os.path.dirname(os.path.abspath(__file__)) except NameError: # __file__ is not defined (e.g., in Jupyter), use cwd current_dir = os.getcwd() scripts_path = os.path.abspath(os.path.join(current_dir, '../../scripts')) if scripts_path not in sys.path: sys.path.insert(0, scripts_path) # Also add the current directory to the path for local imports if current_dir not in sys.path: sys.path.insert(0, current_dir) # Add the parent directory to the path for data access parent_dir = os.path.abspath(os.path.join(current_dir, '..')) if parent_dir not in sys.path: sys.path.insert(0, parent_dir) import common_functions as cf # Import health-seeking, diagnostic, and treatment interventions from tbsim.interventions.tb_health_seeking import HealthSeekingBehavior from tbsim.interventions.tb_diagnostic import TBDiagnostic from tbsim.interventions.tb_treatment import TBTreatment
class AgeDependentTBProgression(ss.Intervention):
    """
    Intervention to modify TB progression rates and fast progressor fractions based on age groups.

    This intervention adjusts:
    1. The relative risk (rr_activation) multiplier for TB progression from latent to active disease
    2. The probability of becoming a fast progressor (p_latent_fast) after TB infection

    Age-specific multipliers (defaults):
    - 0-4 years: 2.0x the base rate
    - 5-14 years: 0.5x the base rate
    - 15+ years: 1.0x the base rate

    Age bands are half-open in continuous age (0 <= age < 5, 5 <= age < 15,
    age >= 15) so that every age falls into exactly one band.
    """

    def __init__(self, pars, **kwargs):
        """
        Args:
            pars: Parameter overrides forwarded to ``update_pars``.
            **kwargs: Forwarded to ``ss.Intervention.__init__`` and ``update_pars``.
        """
        super().__init__(**kwargs)
        self.define_pars(
            age_0_4_multiplier=2.0,  # 2x progression for 0-4 year olds
            age_5_14_multiplier=0.5,  # 0.5x progression for 5-14 year olds
            age_15plus_multiplier=1.0,  # 1x progression for 15+ year olds
        )
        self.update_pars(pars, **kwargs)

    def step(self):
        """Apply age-dependent TB progression multipliers and fast progressor fractions."""
        tb = self.sim.diseases['tb']
        people = self.sim.people

        # Ages of TB-infected individuals
        uids_tb = tb.infected.uids
        ages = people.age[uids_tb]

        # Scale rr_activation per age band
        mask_0_4 = (ages >= 0) & (ages < 5)
        tb.rr_activation[uids_tb[mask_0_4]] *= self.pars.age_0_4_multiplier

        mask_5_14 = (ages >= 5) & (ages < 15)
        tb.rr_activation[uids_tb[mask_5_14]] *= self.pars.age_5_14_multiplier

        mask_15plus = ages >= 15
        tb.rr_activation[uids_tb[mask_15plus]] *= self.pars.age_15plus_multiplier

        # Age-specific fast-progressor probabilities, scaled from a hard-coded base
        base_p_latent_fast = 0.1  # stated by the original author to be the TB model default
        self.age_specific_p_latent_fast = {
            '0_4': base_p_latent_fast * self.pars.age_0_4_multiplier,
            '5_14': base_p_latent_fast * self.pars.age_5_14_multiplier,
            '15plus': base_p_latent_fast * self.pars.age_15plus_multiplier,
        }

        # Wrap the TB model's infect method once; redoing it every step would
        # only rebuild identical distributions and closures.
        # NOTE(review): this branch only runs if the TB module exposes a
        # `p_latent_fast` attribute -- confirm against the tbsim TB class.
        if hasattr(tb, 'p_latent_fast') and not hasattr(self, 'original_infect'):
            self.original_p_latent_fast = tb.p_latent_fast
            self._install_age_dependent_infect(tb)

    def _install_age_dependent_infect(self, tb):
        """Attach age-specific bernoulli distributions to ``tb`` and wrap ``tb.infect``."""
        probs = self.age_specific_p_latent_fast
        tb.p_latent_fast_0_4 = ss.bernoulli(p=probs['0_4'])
        tb.p_latent_fast_5_14 = ss.bernoulli(p=probs['5_14'])
        tb.p_latent_fast_15plus = ss.bernoulli(p=probs['15plus'])

        # tb.infect is a *bound* method: it already carries `tb` as self
        self.original_infect = tb.infect
        original_infect = self.original_infect

        def age_dependent_infect(tb_self, uids, hosp=None, hosp_max=None, source_uids=None, **kwargs):
            """Run the original infect, then redraw fast/slow latency by age band."""
            # Do not pass tb_self again: the saved method is already bound, and an
            # extra positional argument would shift every parameter by one.
            # NOTE(review): assumes the original infect accepts these arguments -- confirm.
            result = original_infect(uids, hosp, hosp_max, source_uids, **kwargs)

            # Individuals among `uids` that are now latent (slow first, then fast)
            states = tb_self.state[uids]
            newly_infected = uids[states == mtb.TBS.LATENT_SLOW]
            newly_infected = np.append(newly_infected, uids[states == mtb.TBS.LATENT_FAST])
            if len(newly_infected) == 0:
                return result

            newly_infected_ages = self.sim.people.age[newly_infected]
            bands = (
                ((newly_infected_ages >= 0) & (newly_infected_ages < 5), tb_self.p_latent_fast_0_4),
                ((newly_infected_ages >= 5) & (newly_infected_ages < 15), tb_self.p_latent_fast_5_14),
                (newly_infected_ages >= 15, tb_self.p_latent_fast_15plus),
            )
            for band_mask, fast_dist in bands:
                band_uids = newly_infected[band_mask]
                if len(band_uids) == 0:
                    continue
                # rvs() gives one boolean per uid; filter() would return the
                # passing uids, which cannot serve as an np.where condition here.
                is_fast = fast_dist.rvs(band_uids)
                tb_self.state[band_uids] = np.where(is_fast, mtb.TBS.LATENT_FAST, mtb.TBS.LATENT_SLOW)

            return result

        # Replace the TB model's infect method with the bound wrapper
        tb.infect = age_dependent_infect.__get__(tb, type(tb))
def make_people(n_agents, age_data=None):
    """
    Create the simulation population.

    Args:
        n_agents: Number of agents to create.
        age_data: Optional DataFrame with ``age`` and ``value`` columns. When
            omitted, a built-in 5-year-bin age structure is used (noted by the
            original author as an approximation of South Africa in 1960).

    Returns:
        ss.People: Population carrying the extra states from
        ``mtb.get_extrastates()`` (required by the health-seeking and
        diagnostic interventions, per the original author's note).
    """
    if age_data is None:
        # South Africa 1960 approximate
        bin_starts = np.arange(0, 101, 5)
        bin_weights = [
            12000, 10000, 8500, 7500, 6500, 5500, 4500,
            3500, 2500, 2000, 1500, 1200, 800, 500,
            300, 150, 80, 40, 15, 5, 1,
        ]
        age_data = pd.DataFrame({'age': bin_starts, 'value': bin_weights})

    extra = mtb.get_extrastates()
    return ss.People(n_agents=n_agents, age_data=age_data, extra_states=extra)
def compute_latent_prevalence(sim):
    """
    Return latent TB prevalence (slow + fast latent counts over population size).

    Args:
        sim: Simulation whose ``results['tb']`` holds ``n_latent_slow`` and
            ``n_latent_fast``.

    Returns:
        Latent count divided by ``results['n_alive']``. If that key is missing,
        the number of people currently alive is used as a constant denominator.
    """
    tb_results = sim.results['tb']
    n_latent = tb_results['n_latent_slow'] + tb_results['n_latent_fast']

    try:
        denominator = sim.results['n_alive']
    except KeyError:
        # No time-aligned series available: fall back to the current alive count
        alive_now = np.count_nonzero(sim.people.alive)
        denominator = np.full_like(n_latent, fill_value=alive_now)

    return n_latent / denominator
def debug_hiv_results(sim):
    """
    Print the HIV result keys and, for each, its last entry.

    Values that have no length (or are empty) are printed as-is. Any error while
    reading ``sim.results['hiv']`` is reported on stdout rather than raised.
    """
    print("=== HIV Results Debug ===")
    try:
        hiv_results = sim.results['hiv']
        print(f"Available HIV result keys: {list(hiv_results.keys())}")
        print("HIV result values at final timestep:")
        for name, series in hiv_results.items():
            is_nonempty_sequence = hasattr(series, '__len__') and len(series) > 0
            shown = series[-1] if is_nonempty_sequence else series
            print(f" {name}: {shown}")
    except Exception as e:
        print(f"Error accessing HIV results: {e}")
    print("========================")
def compute_hiv_prevalence(sim):
    """
    Compute HIV prevalence over time.

    Sources are tried in order:
      1. ``results['hiv']['hiv_prevalence']`` (returned as-is)
      2. ``results['hiv']['infected'] / results['n_alive']``
      3. ``results['hiv']['n_active'] / results['n_alive']``

    If none is available, the HIV result keys are printed for debugging and an
    array of zeros with the length of ``results['timevec']`` is returned.

    Args:
        sim: Simulation object exposing ``results``.

    Returns:
        Array-like HIV prevalence per time step.
    """
    recoverable = (KeyError, AttributeError)

    try:
        return sim.results['hiv']['hiv_prevalence']
    except recoverable:
        pass

    # Fall back to infection counts divided by population size
    for count_key in ('infected', 'n_active'):
        try:
            return sim.results['hiv'][count_key] / sim.results['n_alive']
        except recoverable:
            continue

    # Nothing usable: report what is there (best effort) and return zeros
    try:
        print(f"Available HIV result keys: {list(sim.results['hiv'].keys())}")
    except Exception:
        print("HIV results not found in simulation")
    return np.zeros(len(sim.results['timevec']))
def compute_hiv_prevalence_adults_25plus(sim, target_year=None):
    """
    Compute HIV prevalence for adults 25+.

    Only the simulation's current (final) agent states are available, so the
    value returned for a ``target_year`` is the prevalence among adults 25+ in
    the current state, whichever year is requested.

    Args:
        sim: Simulation object.
        target_year: If not None, return a single prevalence value computed from
            the current agent states. If None, return the overall HIV prevalence
            time series (all ages) as a proxy.

    Returns:
        If target_year specified: float (0.0 when there are no adults 25+).
        If target_year is None: array (prevalence over time).
        On any error a message is printed and 0.0 / an array of zeros is returned.
    """
    try:
        people = sim.people
        hiv_states = sim.diseases.hiv.state

        # HIV-positive states are coded 1, 2, 3
        hiv_positive_mask = np.isin(hiv_states, [1, 2, 3])
        alive_adult_25plus_mask = people.alive & (people.age >= 25)

        if target_year is None:
            # Approximation: only current states exist, so use overall prevalence
            return compute_hiv_prevalence(sim)

        # The time vector is deliberately not consulted here: the states are the
        # current ones regardless of target_year, and reading it could only fail.
        total_adults_25plus = np.sum(alive_adult_25plus_mask)
        if total_adults_25plus == 0:
            return 0.0
        hiv_positive_adults_25plus = np.sum(alive_adult_25plus_mask & hiv_positive_mask)
        return hiv_positive_adults_25plus / total_adults_25plus

    except Exception as e:
        print(f"Error computing HIV prevalence for adults 25+: {e}")
        if target_year is not None:
            return 0.0
        time_length = len(sim.results['timevec'])
        return np.zeros(time_length)
def compute_hiv_prevalence_adults_15to24(sim, target_year=None):
    """
    Compute HIV prevalence for adults 15-24.

    The age band is ``15 <= age < 25``: ages are continuous, so this covers
    everyone up to their 25th birthday and leaves no gap before the 25+ group.
    Only the simulation's current (final) agent states are available, so the
    value returned for a ``target_year`` is computed from the current state,
    whichever year is requested.

    Args:
        sim: Simulation object.
        target_year: If not None, return a single prevalence value computed from
            the current agent states. If None, return the overall HIV prevalence
            time series (all ages) as a proxy.

    Returns:
        If target_year specified: float (0.0 when the age band is empty).
        If target_year is None: array (prevalence over time).
        On any error a message is printed and 0.0 / an array of zeros is returned.
    """
    try:
        people = sim.people
        hiv_states = sim.diseases.hiv.state

        # HIV-positive states are coded 1, 2, 3
        hiv_positive_mask = np.isin(hiv_states, [1, 2, 3])
        adult_15to24_mask = (people.age >= 15) & (people.age < 25)
        alive_adult_15to24_mask = people.alive & adult_15to24_mask

        if target_year is None:
            # Approximation: only current states exist, so use overall prevalence
            return compute_hiv_prevalence(sim)

        # The time vector is deliberately not consulted here: the states are the
        # current ones regardless of target_year, and reading it could only fail.
        total_adults_15to24 = np.sum(alive_adult_15to24_mask)
        if total_adults_15to24 == 0:
            return 0.0
        hiv_positive_adults_15to24 = np.sum(alive_adult_15to24_mask & hiv_positive_mask)
        return hiv_positive_adults_15to24 / total_adults_15to24

    except Exception as e:
        print(f"Error computing HIV prevalence for adults 15-24: {e}")
        if target_year is not None:
            return 0.0
        time_length = len(sim.results['timevec'])
        return np.zeros(time_length)
def compute_hiv_positive_tb_prevalence(sim):
    """
    Compute HIV-positive TB prevalence as a proportion of the total population.

    Uses ``results['tb']['n_active_hiv_positive'] / results['n_alive']`` when
    available. Otherwise falls back to a rough estimate built from overall HIV
    prevalence and active TB prevalence, and finally to an array of zeros with
    the length of ``results['timevec']``.
    """
    recoverable = (KeyError, AttributeError)

    # Preferred source: direct co-infection counts
    try:
        coinfected = sim.results['tb']['n_active_hiv_positive']
        population = sim.results['n_alive']
        return coinfected / population
    except recoverable:
        pass

    # Rough estimate from the two marginal prevalences
    try:
        hiv_prevalence = compute_hiv_prevalence(sim)
        active_tb_prevalence = sim.results['tb']['prevalence_active']
        # NOTE(review): the original comment reads "assume 30% of TB cases are
        # HIV-positive", yet the product also scales by HIV prevalence -- confirm
        # which estimate is intended.
        return hiv_prevalence * active_tb_prevalence * 0.3
    except recoverable:
        n_steps = len(sim.results['timevec'])
        return np.zeros(n_steps)
def compute_age_stratified_prevalence(sim, target_year=2018):
    """
    Compute age-stratified active TB prevalence from the simulation's agent states.

    Only the current (final) agent states are available, so the result reflects
    those states; ``target_year`` is accepted for interface compatibility but
    does not select a different point in time.

    Age groups are half-open in continuous age: a group labelled ``'5-14'``
    covers ``5 <= age < 15``, so every age below 201 falls into exactly one group.

    Args:
        sim: Simulation object.
        target_year: Nominal year of the estimate (see note above).

    Returns:
        dict: Maps each age-group label to a dict with ``prevalence``,
        ``prevalence_per_100k``, ``total_people`` and ``tb_cases``.
    """
    people = sim.people
    alive_mask = people.alive

    # Active TB covers smear-positive, smear-negative and extra-pulmonary states
    tb_states = sim.diseases.tb.state
    active_tb_mask = np.isin(
        tb_states,
        [mtb.TBS.ACTIVE_SMPOS, mtb.TBS.ACTIVE_SMNEG, mtb.TBS.ACTIVE_EXPTB],
    )

    ages = people.age[alive_mask]
    active_tb_ages = people.age[alive_mask & active_tb_mask]

    # Age groups including children and adolescents (bounds in completed years)
    age_groups = [(0, 4), (5, 14), (15, 24), (25, 34), (35, 44), (45, 54), (55, 64), (65, 200)]
    age_group_labels = ['0-4', '5-14', '15-24', '25-34', '35-44', '45-54', '55-64', '65+']

    prevalence_by_age = {}
    for label, (min_age, max_age) in zip(age_group_labels, age_groups):
        upper = max_age + 1  # exclusive bound, so e.g. age 4.5 belongs to '0-4'
        total_in_age_group = np.sum((ages >= min_age) & (ages < upper))
        tb_in_age_group = np.sum((active_tb_ages >= min_age) & (active_tb_ages < upper))

        if total_in_age_group > 0:
            prevalence = tb_in_age_group / total_in_age_group
            prevalence_per_100k = prevalence * 100000
        else:
            prevalence = 0
            prevalence_per_100k = 0

        prevalence_by_age[label] = {
            'prevalence': prevalence,
            'prevalence_per_100k': prevalence_per_100k,
            'total_people': total_in_age_group,
            'tb_cases': tb_in_age_group,
        }

    return prevalence_by_age
def compute_age_stratified_prevalence_time_series(sim):
    """
    Build an age-stratified active TB prevalence table indexed by year.

    Historical agent states are not tracked, so the prevalence is computed once
    from the current (final) agent states and repeated for every time point;
    all rows are therefore identical.

    Age groups are half-open in continuous age: a group labelled ``'5-14'``
    covers ``5 <= age < 15``.

    Args:
        sim: Simulation object.

    Returns:
        pd.DataFrame: One row per entry of ``sim.results['timevec']`` (index is
        the entry's year), one float column per age group, values in cases per
        100,000.
    """
    # Age groups including children and adolescents (bounds in completed years)
    age_groups = [(0, 4), (5, 14), (15, 24), (25, 34), (35, 44), (45, 54), (55, 64), (65, 200)]
    age_group_labels = ['0-4', '5-14', '15-24', '25-34', '35-44', '45-54', '55-64', '65+']

    time_years = np.array([d.year for d in sim.results['timevec']])

    # The current agent state does not depend on the time point, so derive the
    # per-group prevalence a single time instead of once per time step.
    people = sim.people
    alive_mask = people.alive
    tb_states = sim.diseases.tb.state
    active_tb_mask = np.isin(
        tb_states,
        [mtb.TBS.ACTIVE_SMPOS, mtb.TBS.ACTIVE_SMNEG, mtb.TBS.ACTIVE_EXPTB],
    )
    ages = people.age[alive_mask]
    active_tb_ages = people.age[alive_mask & active_tb_mask]

    per_100k = []
    for min_age, max_age in age_groups:
        upper = max_age + 1  # exclusive bound, so e.g. age 4.5 belongs to '0-4'
        total_in_age_group = np.sum((ages >= min_age) & (ages < upper))
        tb_in_age_group = np.sum((active_tb_ages >= min_age) & (active_tb_ages < upper))
        if total_in_age_group > 0:
            per_100k.append(tb_in_age_group / total_in_age_group * 100000)
        else:
            per_100k.append(0)

    # Repeat the single row for every time point
    values = np.tile(np.asarray(per_100k, dtype=float), (len(time_years), 1))
    return pd.DataFrame(values, index=time_years, columns=age_group_labels)
def plot_total_population_grid(sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp):
    """
    Plot total population over time for every parameter combination and save a PDF.

    Rows are (tb_mortality, rel_sus) pairs and columns are beta values.

    Args:
        sim_grid: Nested sequence indexed as ``sim_grid[m][i][j]`` for mortality
            index ``m``, rel_sus index ``i`` and beta index ``j``.
        beta_vals: Swept beta values (one column each).
        rel_sus_vals: Swept rel_sus values.
        tb_mortality_vals: Swept TB mortality values.
        timestamp: String appended to the output file name.
    """
    nrows = len(tb_mortality_vals) * len(rel_sus_vals)
    ncols = len(beta_vals)
    # squeeze=False keeps axs two-dimensional even for a single row or column,
    # so axs[row, col] works for every grid shape (including 1x1).
    fig, axs = plt.subplots(
        nrows, ncols, figsize=(4 * ncols, 3 * nrows),
        sharex=True, sharey=True, squeeze=False,
    )

    for m, tb_mortality in enumerate(tb_mortality_vals):
        for i, rel_sus in enumerate(rel_sus_vals):
            ax_idx = m * len(rel_sus_vals) + i
            for j, beta in enumerate(beta_vals):
                sim = sim_grid[m][i][j]
                timevec = sim.results['timevec']
                n_alive = sim.results['n_alive']

                ax = axs[ax_idx, j]
                ax.plot(timevec, n_alive, color='blue', label='Total Population')
                ax.set_title(f'β={beta:.3f}, rel_sus={rel_sus:.2f}, mort={tb_mortality:.1e}')
                ax.grid(True)
                if ax_idx == nrows - 1:
                    ax.set_xlabel('Year')
                if j == 0:
                    ax.set_ylabel('Population Size')
                if m == 0 and i == 0 and j == 0:
                    ax.legend(fontsize=6)

    plt.tight_layout()
    plt.suptitle('Simulated Total Population', fontsize=14, y=1.02)

    # Consistent x-axis ticks (every 20 years) for all subplots
    first_sim = sim_grid[0][0][0]
    time_years = np.array([d.year for d in first_sim.results['timevec']])
    xticks = np.arange(time_years.min(), time_years.max() + 1, 20)
    tick_dates = [datetime.date(int(year), 1, 1) for year in xticks]
    tick_labels = [str(year) for year in xticks]
    for ax in axs.flat:
        ax.set_xticks(tick_dates)
        ax.set_xticklabels(tick_labels, rotation=45)

    filename = f"total_population_grid_{timestamp}.pdf"
    plt.savefig(get_output_path(filename), dpi=300, bbox_inches='tight')
    plt.show()
def plot_hiv_metrics_grid(sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp):
    """
    Plot HIV prevalence for both age groups (15-24 and 25+) with target data points.

    Rows are (tb_mortality, rel_sus) pairs and columns are beta values. HIV
    debug output is printed for the first simulation only. The figure is saved
    as a PDF and shown.

    Args:
        sim_grid: Nested sequence indexed as ``sim_grid[m][i][j]`` for mortality
            index ``m``, rel_sus index ``i`` and beta index ``j``.
        beta_vals: Swept beta values (one column each).
        rel_sus_vals: Swept rel_sus values.
        tb_mortality_vals: Swept TB mortality values.
        timestamp: String appended to the output file name.
    """
    import matplotlib.ticker as mtick

    # Target data points for adults 25+ from van Schalkwyk et al. 2021 (eThekwini, South Africa)
    estimated_data_25plus = [
        (1990, 0.01),  # 1% in 1990
        (1995, 0.04),  # 4% in 1995
        (2000, 0.11),  # 11% in 2000
        (2010, 0.18),  # 18% in 2010
        (2015, 0.19),  # 19% in 2015
    ]
    survey_data_25plus = [
        (2005, 0.19),  # 19% in 2005
        (2008, 0.22),  # 22% in 2008
        (2013, 0.21),  # 21% in 2013
        (2018, 0.25),  # 25% in 2018
    ]

    # Target data points for adults 15-24 (simplified non-decreasing trend)
    estimated_data_15to24 = [
        (1990, 0.01),  # 1% in 1990
        (1995, 0.05),  # 5% in 1995
        (2000, 0.10),  # 10% in 2000
        (2005, 0.10),  # 10% in 2005
        (2010, 0.10),  # 10% in 2010
        (2015, 0.10),  # 10% in 2015
    ]

    nrows = len(tb_mortality_vals) * len(rel_sus_vals)
    ncols = len(beta_vals)
    # squeeze=False keeps axs two-dimensional even for a single row or column,
    # so axs[row, col] works for every grid shape (including 1x1).
    fig, axs = plt.subplots(
        nrows, ncols, figsize=(4 * ncols, 6 * nrows),
        sharex=True, sharey=True, squeeze=False,
    )

    for m, tb_mortality in enumerate(tb_mortality_vals):
        for i, rel_sus in enumerate(rel_sus_vals):
            ax_idx = m * len(rel_sus_vals) + i
            for j, beta in enumerate(beta_vals):
                sim = sim_grid[m][i][j]
                timevec = sim.results['timevec']
                is_first_panel = (m == 0 and i == 0 and j == 0)

                # Debug: print HIV results for the first simulation
                if is_first_panel:
                    debug_hiv_results(sim)

                hiv_prev_25plus = compute_hiv_prevalence_adults_25plus(sim)
                hiv_prev_15to24 = compute_hiv_prevalence_adults_15to24(sim)

                ax = axs[ax_idx, j]

                # Adults 25+: model curve, estimated points, survey points.
                # Only the first point of each series carries a legend label.
                ax.plot(timevec, hiv_prev_25plus, label='Model HIV Prevalence (25+)', color='blue', linewidth=2)
                for year, prev in estimated_data_25plus:
                    ax.plot(datetime.date(year, 1, 1), prev, 'go', markersize=4, alpha=0.8,
                            label='Estimated Data (25+)' if year == 1990 else "")
                for year, prev in survey_data_25plus:
                    ax.plot(datetime.date(year, 1, 1), prev, 'ro', markersize=4, alpha=0.8,
                            label='Survey Data (25+)' if year == 2005 else "")

                # Adults 15-24: model curve and estimated points
                ax.plot(timevec, hiv_prev_15to24, label='Model HIV Prevalence (15-24)', color='orange',
                        linewidth=2, linestyle='--')
                for year, prev in estimated_data_15to24:
                    ax.plot(datetime.date(year, 1, 1), prev, 'mo', markersize=4, alpha=0.8,
                            label='Estimated Data (15-24)' if year == 1990 else "")

                ax.set_title(f'β={beta:.3f}, rel_sus={rel_sus:.2f}, mort={tb_mortality:.1e}')
                ax.grid(True, alpha=0.3)
                if ax_idx == nrows - 1:
                    ax.set_xlabel('Year')
                if j == 0:
                    ax.set_ylabel('HIV Prevalence')
                if is_first_panel:
                    ax.legend(fontsize=6)

                # Show the y-axis as percentages
                ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0))

    plt.tight_layout()
    plt.suptitle('HIV Prevalence by Age Group: Model vs van Schalkwyk et al. 2021 Data (eThekwini, South Africa)', fontsize=14, y=1.02)

    # Consistent x-axis ticks (every 20 years) for all subplots
    first_sim = sim_grid[0][0][0]
    time_years = np.array([d.year for d in first_sim.results['timevec']])
    xticks = np.arange(time_years.min(), time_years.max() + 1, 20)
    tick_dates = [datetime.date(int(year), 1, 1) for year in xticks]
    tick_labels = [str(year) for year in xticks]
    for ax in axs.flat:
        ax.set_xticks(tick_dates)
        ax.set_xticklabels(tick_labels, rotation=45)

    filename = f"hiv_metrics_grid_{timestamp}.pdf"
    plt.savefig(get_output_path(filename), dpi=300, bbox_inches='tight')
    plt.show()
def plot_active_tb_sweep_with_data(sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp):
    """
    Plot active TB prevalence for all parameter combinations and save a PDF.

    Each panel shows the model's active prevalence, a 1% target line and the
    2018 South Africa data point (0.852%). Rows are (tb_mortality, rel_sus)
    pairs and columns are beta values.

    Args:
        sim_grid: Nested sequence indexed as ``sim_grid[m][i][j]`` for mortality
            index ``m``, rel_sus index ``i`` and beta index ``j``.
        beta_vals: Swept beta values (one column each).
        rel_sus_vals: Swept rel_sus values.
        tb_mortality_vals: Swept TB mortality values.
        timestamp: String appended to the output file name.
    """
    import matplotlib.pyplot as plt

    nrows = len(tb_mortality_vals) * len(rel_sus_vals)
    ncols = len(beta_vals)
    # squeeze=False keeps axs two-dimensional even for a single row or column,
    # so axs[row, col] works for every grid shape (including 1x1).
    fig, axs = plt.subplots(
        nrows, ncols, figsize=(4 * ncols, 3 * nrows),
        sharex=True, sharey=True, squeeze=False,
    )

    for m, tb_mortality in enumerate(tb_mortality_vals):
        for i, rel_sus in enumerate(rel_sus_vals):
            ax_idx = m * len(rel_sus_vals) + i
            for j, beta in enumerate(beta_vals):
                sim = sim_grid[m][i][j]
                timevec = sim.results['timevec']
                active_prev = sim.results['tb']['prevalence_active']

                ax = axs[ax_idx, j]
                ax.plot(timevec, active_prev, label='Active TB Prevalence', color='blue', linewidth=2)
                ax.axhline(0.01, color='red', linestyle=':', linewidth=1, label='Target 1%')

                # The 2018 SA data point (real data)
                ax.plot(datetime.date(2018, 1, 1), 0.00852, 'ro', markersize=6, label='2018 SA data (0.852%)')

                ax.set_title(f'β={beta:.3f}, rel_sus={rel_sus:.2f}, mort={tb_mortality:.1e}')
                if ax_idx == nrows - 1:
                    ax.set_xlabel('Year')
                if j == 0:
                    ax.set_ylabel('Active TB Prevalence')
                ax.grid(True)
                if m == 0 and i == 0 and j == 0:
                    ax.legend(fontsize=6)

    plt.tight_layout()
    plt.suptitle('Active TB Prevalence Sweep', fontsize=16, y=1.02)

    filename = f"active_tb_prevalence_sweep_{timestamp}.pdf"
    plt.savefig(get_output_path(filename), dpi=300, bbox_inches='tight')
    plt.show()
def plot_latent_tb_sweep_with_data(sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp):
    """
    Plot latent TB prevalence for all parameter combinations and save a PDF.

    Each panel shows the model's latent prevalence and a 50% target line. Rows
    are (tb_mortality, rel_sus) pairs and columns are beta values.

    Args:
        sim_grid: Nested sequence indexed as ``sim_grid[m][i][j]`` for mortality
            index ``m``, rel_sus index ``i`` and beta index ``j``.
        beta_vals: Swept beta values (one column each).
        rel_sus_vals: Swept rel_sus values.
        tb_mortality_vals: Swept TB mortality values.
        timestamp: String appended to the output file name.
    """
    import matplotlib.pyplot as plt

    nrows = len(tb_mortality_vals) * len(rel_sus_vals)
    ncols = len(beta_vals)
    # squeeze=False keeps axs two-dimensional even for a single row or column,
    # so axs[row, col] works for every grid shape (including 1x1).
    fig, axs = plt.subplots(
        nrows, ncols, figsize=(4 * ncols, 3 * nrows),
        sharex=True, sharey=True, squeeze=False,
    )

    for m, tb_mortality in enumerate(tb_mortality_vals):
        for i, rel_sus in enumerate(rel_sus_vals):
            ax_idx = m * len(rel_sus_vals) + i
            for j, beta in enumerate(beta_vals):
                sim = sim_grid[m][i][j]
                timevec = sim.results['timevec']
                latent_prev = compute_latent_prevalence(sim)

                ax = axs[ax_idx, j]
                ax.plot(timevec, latent_prev, label='Latent TB Prevalence', color='orange',
                        linewidth=2, linestyle='--')
                ax.axhline(0.5, color='red', linestyle=':', linewidth=1, label='Target 50%')

                ax.set_title(f'β={beta:.3f}, rel_sus={rel_sus:.2f}, mort={tb_mortality:.1e}')
                if ax_idx == nrows - 1:
                    ax.set_xlabel('Year')
                if j == 0:
                    ax.set_ylabel('Latent TB Prevalence')
                ax.grid(True)
                if m == 0 and i == 0 and j == 0:
                    ax.legend(fontsize=6)

    plt.tight_layout()
    plt.suptitle('Latent TB Prevalence Sweep', fontsize=16, y=1.02)

    filename = f"latent_tb_prevalence_sweep_{timestamp}.pdf"
    plt.savefig(get_output_path(filename), dpi=300, bbox_inches='tight')
    plt.show()
def plot_tb_sweep_with_data(sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp):
    """
    Legacy entry point that draws both TB prevalence sweeps.

    Forwards its arguments unchanged, first to
    ``plot_active_tb_sweep_with_data`` and then to
    ``plot_latent_tb_sweep_with_data``.
    """
    sweep_args = (sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp)
    # Active prevalence is drawn first, latent second (same order as before).
    for draw_sweep in (plot_active_tb_sweep_with_data, plot_latent_tb_sweep_with_data):
        draw_sweep(*sweep_args)
def compute_annualized_infection_rate(sim):
    """
    Compute an annualized TB infection rate (percent of population) over time.

    Two estimates are available:

    1. Method 2 (preferred): change in the infected count
       (``n_latent_slow + n_latent_fast + n_active``) between step T and the
       step one year earlier, divided by the population at T. For steps
       inside the first year there is no earlier reference point, so the
       infected count at that step divided by the population is reported.
    2. Method 1 (fallback): trailing one-year sum of ``new_infections``
       divided by the population at T.

    Method 1 is used only when Method 2 cannot be computed (for example a
    required result is missing) or yields only NaN. If neither method works a
    warning is printed and an all-zero array is returned.

    The number of steps per year is derived from the spacing between the
    first two entries of ``timevec``. Population comes from
    ``sim.results['n_alive']``; if that key is missing, the current number of
    alive agents is used for every step. Steps with zero population get 0.

    Args:
        sim: Simulation exposing ``results['timevec']``, ``results['tb']`` and
            either ``results['n_alive']`` or ``people.alive``.

    Returns:
        numpy.ndarray: One value per time step, in percent.
    """
    time = sim.results['timevec']
    tb_results = sim.results['tb']
    n_steps = len(time)

    # Get population size over time
    try:
        n_alive = sim.results['n_alive']
    except KeyError:
        n_alive = np.full(n_steps, fill_value=np.count_nonzero(sim.people.alive))
    n_alive = np.asarray(getattr(n_alive, 'values', n_alive), dtype=float)

    def steps_per_year():
        # Window length (in steps) covering roughly 365 days.
        days_per_step = (time[1] - time[0]).days if n_steps > 1 else 1
        return max(1, int(365 / days_per_step))

    def series(name):
        return np.asarray(tb_results[name].values, dtype=float)

    def change_over_window(values, window):
        # values[i] - values[i - window]; steps without a full window keep
        # values[i] itself (i.e. a baseline of zero).
        baseline = np.zeros(n_steps, dtype=float)
        if window < n_steps:
            baseline[window:] = values[:-window]
        return values - baseline

    def as_percent_of_population(counts):
        rate = np.zeros(n_steps, dtype=float)
        np.divide(counts, n_alive, out=rate, where=n_alive > 0)
        return rate * 100

    # Method 2: difference in infected count over one year.
    # The broad except is deliberate: this is a best-effort diagnostic, and
    # any failure should fall through to Method 1 rather than abort a sweep.
    try:
        n_infected = series('n_latent_slow') + series('n_latent_fast') + series('n_active')
        rate = as_percent_of_population(change_over_window(n_infected, steps_per_year()))
        if not np.all(np.isnan(rate)):
            return rate
    except Exception as e:
        print(f"Method 2 failed: {e}")

    # Method 1: trailing one-year sum of new infections (if available).
    try:
        if 'new_infections' in tb_results:
            running_total = np.cumsum(series('new_infections'))
            rate = as_percent_of_population(change_over_window(running_total, steps_per_year()))
            if not np.all(np.isnan(rate)):
                return rate
    except Exception as e:
        print(f"Method 1 failed: {e}")

    print("Warning: Could not compute annualized infection rate")
    return np.zeros(n_steps, dtype=float)
def plot_annualized_infection_rate_grid(sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp):
    """
    Plot the annualized TB infection rate for every parameter combination.

    One panel is drawn per (tb_mortality, rel_sus, beta) combination: rows
    enumerate the (tb_mortality, rel_sus) pairs and columns enumerate beta.
    Each panel shows the series from ``compute_annualized_infection_rate``
    together with a dotted reference line at 2%. The figure is saved as
    ``annualized_infection_rate_grid_<timestamp>.pdf`` at the location
    returned by ``get_output_path`` and then shown.

    Args:
        sim_grid: Nested sequence of simulations indexed as
            ``sim_grid[m][i][j]`` for ``tb_mortality_vals[m]``,
            ``rel_sus_vals[i]`` and ``beta_vals[j]``.
        beta_vals: Swept beta values (one subplot column each).
        rel_sus_vals: Swept rel_sus values.
        tb_mortality_vals: Swept TB mortality values.
        timestamp: String embedded in the output file name.
    """
    import matplotlib.pyplot as plt

    nrows = len(tb_mortality_vals) * len(rel_sus_vals)
    ncols = len(beta_vals)
    # squeeze=False keeps `axs` two-dimensional even when the grid has a
    # single row and/or column, so axs[row][col] is valid for any sweep shape.
    fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows),
                            sharex=True, sharey=True, squeeze=False)

    for m, tb_mortality in enumerate(tb_mortality_vals):
        for i, rel_sus in enumerate(rel_sus_vals):
            ax_idx = m * len(rel_sus_vals) + i
            for j, beta in enumerate(beta_vals):
                sim = sim_grid[m][i][j]
                time = sim.results['timevec']
                annual_rate = compute_annualized_infection_rate(sim)

                ax = axs[ax_idx][j]
                ax.plot(time, annual_rate, label='Annual Infection Rate', color='purple', linewidth=2)
                ax.axhline(2.0, color='red', linestyle=':', linewidth=1, label='2% Annual Risk')

                ax.set_title(f'β={beta:.3f}, rel_sus={rel_sus:.2f}, mort={tb_mortality:.1e}')
                if ax_idx == nrows - 1:
                    ax.set_xlabel('Year')
                if j == 0:
                    ax.set_ylabel('Annual Infection Rate (%)')
                ax.grid(True)
                # Only the first panel carries a legend to keep the grid readable.
                if m == 0 and i == 0 and j == 0:
                    ax.legend(fontsize=6)

    plt.tight_layout()
    plt.suptitle('Annualized TB Infection Rate', fontsize=16, y=1.02)
    filename = f"annualized_infection_rate_grid_{timestamp}.pdf"
    plt.savefig(get_output_path(filename), dpi=300, bbox_inches='tight')
    plt.show()
def plot_health_seeking_grid(sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp):
    """
    Plot health-seeking behavior metrics for every parameter combination.

    Metrics plotted (read from ``sim.results['healthseekingbehavior']``):
        - new_sought_care: Number of people who sought care in this timestep (count)
        - n_sought_care: Cumulative number of people who have ever sought care (count)
        - n_eligible: Number of people with active TB eligible for care-seeking (count)

    Rows enumerate the (tb_mortality, rel_sus) pairs and columns enumerate
    beta. X-axis ticks are placed every 20 years over the year range of the
    first simulation. The figure is saved as
    ``health_seeking_grid_<timestamp>.pdf`` at the location returned by
    ``get_output_path`` and then shown.

    Args:
        sim_grid: Nested sequence of simulations indexed as
            ``sim_grid[m][i][j]`` for ``tb_mortality_vals[m]``,
            ``rel_sus_vals[i]`` and ``beta_vals[j]``.
        beta_vals: Swept beta values (one subplot column each).
        rel_sus_vals: Swept rel_sus values.
        tb_mortality_vals: Swept TB mortality values.
        timestamp: String embedded in the output file name.
    """
    nrows = len(tb_mortality_vals) * len(rel_sus_vals)
    ncols = len(beta_vals)
    # squeeze=False keeps `axs` two-dimensional even when the grid has a
    # single row and/or column, so axs[row][col] is valid for any sweep shape.
    fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows),
                            sharex=True, sharey=True, squeeze=False)

    for m, tb_mortality in enumerate(tb_mortality_vals):
        for i, rel_sus in enumerate(rel_sus_vals):
            ax_idx = m * len(rel_sus_vals) + i
            for j, beta in enumerate(beta_vals):
                sim = sim_grid[m][i][j]
                time = sim.results['timevec']

                # Get health-seeking results
                hsb = sim.results['healthseekingbehavior']
                new_sought = hsb['new_sought_care'].values
                n_sought = hsb['n_sought_care'].values
                n_eligible = hsb['n_eligible'].values

                ax = axs[ax_idx][j]
                # New people seeking care each step
                ax.plot(time, new_sought, label='New Sought Care', color='green', linewidth=2)
                # Cumulative people who sought care
                ax.plot(time, n_sought, label='Cumulative Sought Care', color='blue', linestyle='--')
                # Eligible population
                ax.plot(time, n_eligible, label='Eligible (Active TB)', color='red', linestyle=':')

                ax.set_title(f'β={beta:.3f}, rel_sus={rel_sus:.2f}, mort={tb_mortality:.1e}')
                ax.grid(True)
                if ax_idx == nrows - 1:
                    ax.set_xlabel('Year')
                if j == 0:
                    ax.set_ylabel('Number of People')
                # Only the first panel carries a legend to keep the grid readable.
                if m == 0 and i == 0 and j == 0:
                    ax.legend(fontsize=6)

    plt.tight_layout()
    plt.suptitle('Health-Seeking Behavior Over Time', fontsize=14, y=1.02)

    # Set consistent x-axis ticks for all subplots (every 20 years).
    first_sim = sim_grid[0][0][0]
    time_years = np.array([d.year for d in first_sim.results['timevec']])
    tick_years = np.arange(time_years.min(), time_years.max() + 1, 20)
    tick_dates = [datetime.date(int(year), 1, 1) for year in tick_years]
    tick_labels = [str(year) for year in tick_years]
    for ax in axs.flat:
        ax.set_xticks(tick_dates)
        ax.set_xticklabels(tick_labels, rotation=45)

    filename = f"health_seeking_grid_{timestamp}.pdf"
    plt.savefig(get_output_path(filename), dpi=300, bbox_inches='tight')
    plt.show()
def plot_diagnostic_grid(sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp):
    """
    Plot diagnostic testing metrics for every parameter combination.

    Metrics plotted (read from ``sim.results['tbdiagnostic']``):
        - n_tested: Number of people tested in this timestep (count)
        - n_test_positive: Number of positive test results in this timestep (count)
        - n_test_negative: Number of negative test results in this timestep (count)

    Rows enumerate the (tb_mortality, rel_sus) pairs and columns enumerate
    beta. X-axis ticks are placed every 20 years over the year range of the
    first simulation. The figure is saved as
    ``diagnostic_grid_<timestamp>.pdf`` at the location returned by
    ``get_output_path`` and then shown.

    Args:
        sim_grid: Nested sequence of simulations indexed as
            ``sim_grid[m][i][j]`` for ``tb_mortality_vals[m]``,
            ``rel_sus_vals[i]`` and ``beta_vals[j]``.
        beta_vals: Swept beta values (one subplot column each).
        rel_sus_vals: Swept rel_sus values.
        tb_mortality_vals: Swept TB mortality values.
        timestamp: String embedded in the output file name.
    """
    nrows = len(tb_mortality_vals) * len(rel_sus_vals)
    ncols = len(beta_vals)
    # squeeze=False keeps `axs` two-dimensional even when the grid has a
    # single row and/or column, so axs[row][col] is valid for any sweep shape.
    fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows),
                            sharex=True, sharey=True, squeeze=False)

    for m, tb_mortality in enumerate(tb_mortality_vals):
        for i, rel_sus in enumerate(rel_sus_vals):
            ax_idx = m * len(rel_sus_vals) + i
            for j, beta in enumerate(beta_vals):
                sim = sim_grid[m][i][j]
                time = sim.results['timevec']

                # Get per-timestep diagnostic results
                tbdiag = sim.results['tbdiagnostic']
                n_tested = tbdiag['n_tested'].values
                n_test_positive = tbdiag['n_test_positive'].values
                n_test_negative = tbdiag['n_test_negative'].values

                ax = axs[ax_idx][j]
                # Plot incident testing results
                ax.plot(time, n_tested, label='Tested', color='blue', marker='o', markersize=2)
                ax.plot(time, n_test_positive, label='Tested Positive', color='green', linestyle='--')
                ax.plot(time, n_test_negative, label='Tested Negative', color='red', linestyle=':')

                ax.set_title(f'β={beta:.3f}, rel_sus={rel_sus:.2f}, mort={tb_mortality:.1e}')
                ax.grid(True)
                if ax_idx == nrows - 1:
                    ax.set_xlabel('Year')
                if j == 0:
                    ax.set_ylabel('Number of People')
                # Only the first panel carries a legend to keep the grid readable.
                if m == 0 and i == 0 and j == 0:
                    ax.legend(fontsize=6)

    plt.tight_layout()
    plt.suptitle('TB Diagnostic Testing Outcomes', fontsize=14, y=1.02)

    # Set consistent x-axis ticks for all subplots (every 20 years).
    first_sim = sim_grid[0][0][0]
    time_years = np.array([d.year for d in first_sim.results['timevec']])
    tick_years = np.arange(time_years.min(), time_years.max() + 1, 20)
    tick_dates = [datetime.date(int(year), 1, 1) for year in tick_years]
    tick_labels = [str(year) for year in tick_years]
    for ax in axs.flat:
        ax.set_xticks(tick_dates)
        ax.set_xticklabels(tick_labels, rotation=45)

    filename = f"diagnostic_grid_{timestamp}.pdf"
    plt.savefig(get_output_path(filename), dpi=300, bbox_inches='tight')
    plt.show()
def plot_cumulative_diagnostic_grid(sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp):
    """
    Plot cumulative diagnostic testing results for every parameter combination.

    Metrics plotted (read from ``sim.results['tbdiagnostic']``):
        - cum_test_positive: Cumulative number of positive test results over time (count)
        - cum_test_negative: Cumulative number of negative test results over time (count)

    Rows enumerate the (tb_mortality, rel_sus) pairs and columns enumerate
    beta. X-axis ticks are placed every 20 years over the year range of the
    first simulation. The figure is saved as
    ``cumulative_diagnostic_grid_<timestamp>.pdf`` at the location returned by
    ``get_output_path`` and then shown.

    Args:
        sim_grid: Nested sequence of simulations indexed as
            ``sim_grid[m][i][j]`` for ``tb_mortality_vals[m]``,
            ``rel_sus_vals[i]`` and ``beta_vals[j]``.
        beta_vals: Swept beta values (one subplot column each).
        rel_sus_vals: Swept rel_sus values.
        tb_mortality_vals: Swept TB mortality values.
        timestamp: String embedded in the output file name.
    """
    nrows = len(tb_mortality_vals) * len(rel_sus_vals)
    ncols = len(beta_vals)
    # squeeze=False keeps `axs` two-dimensional even when the grid has a
    # single row and/or column, so axs[row][col] is valid for any sweep shape.
    fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows),
                            sharex=True, sharey=True, squeeze=False)

    for m, tb_mortality in enumerate(tb_mortality_vals):
        for i, rel_sus in enumerate(rel_sus_vals):
            ax_idx = m * len(rel_sus_vals) + i
            for j, beta in enumerate(beta_vals):
                sim = sim_grid[m][i][j]
                time = sim.results['timevec']

                # Get cumulative diagnostic results
                tbdiag = sim.results['tbdiagnostic']
                cum_test_positive = tbdiag['cum_test_positive'].values
                cum_test_negative = tbdiag['cum_test_negative'].values

                ax = axs[ax_idx][j]
                # Plot cumulative testing results
                ax.plot(time, cum_test_positive, label='Cumulative Positives', color='green', linestyle='--')
                ax.plot(time, cum_test_negative, label='Cumulative Negatives', color='red', linestyle=':')

                ax.set_title(f'β={beta:.3f}, rel_sus={rel_sus:.2f}, mort={tb_mortality:.1e}')
                ax.grid(True)
                if ax_idx == nrows - 1:
                    ax.set_xlabel('Year')
                if j == 0:
                    ax.set_ylabel('Cumulative Tests')
                # Only the first panel carries a legend to keep the grid readable.
                if m == 0 and i == 0 and j == 0:
                    ax.legend(fontsize=6)

    plt.tight_layout()
    plt.suptitle('Cumulative TB Diagnostic Results', fontsize=14, y=1.02)

    # Set consistent x-axis ticks for all subplots (every 20 years).
    first_sim = sim_grid[0][0][0]
    time_years = np.array([d.year for d in first_sim.results['timevec']])
    tick_years = np.arange(time_years.min(), time_years.max() + 1, 20)
    tick_dates = [datetime.date(int(year), 1, 1) for year in tick_years]
    tick_labels = [str(year) for year in tick_years]
    for ax in axs.flat:
        ax.set_xticks(tick_dates)
        ax.set_xticklabels(tick_labels, rotation=45)

    filename = f"cumulative_diagnostic_grid_{timestamp}.pdf"
    plt.savefig(get_output_path(filename), dpi=300, bbox_inches='tight')
    plt.show()
def plot_treatment_grid(sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp):
    """
    Plot TB treatment outcomes for every parameter combination.

    Metrics plotted (read from ``sim.results['tbtreatment']``):
        - n_treated: Number of people who started treatment in this timestep (count)
        - n_treatment_success: Number of successful treatment completions in this timestep (count)
        - n_treatment_failure: Number of failed treatment attempts in this timestep (count)

    Rows enumerate the (tb_mortality, rel_sus) pairs and columns enumerate
    beta. X-axis ticks are placed every 20 years over the year range of the
    first simulation. The figure is saved as
    ``treatment_grid_<timestamp>.pdf`` at the location returned by
    ``get_output_path`` and then shown.

    Args:
        sim_grid: Nested sequence of simulations indexed as
            ``sim_grid[m][i][j]`` for ``tb_mortality_vals[m]``,
            ``rel_sus_vals[i]`` and ``beta_vals[j]``.
        beta_vals: Swept beta values (one subplot column each).
        rel_sus_vals: Swept rel_sus values.
        tb_mortality_vals: Swept TB mortality values.
        timestamp: String embedded in the output file name.
    """
    nrows = len(tb_mortality_vals) * len(rel_sus_vals)
    ncols = len(beta_vals)
    # squeeze=False keeps `axs` two-dimensional even when the grid has a
    # single row and/or column, so axs[row][col] is valid for any sweep shape.
    fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows),
                            sharex=True, sharey=True, squeeze=False)

    for m, tb_mortality in enumerate(tb_mortality_vals):
        for i, rel_sus in enumerate(rel_sus_vals):
            ax_idx = m * len(rel_sus_vals) + i
            for j, beta in enumerate(beta_vals):
                sim = sim_grid[m][i][j]
                time = sim.results['timevec']

                # Get treatment results
                tbtx = sim.results['tbtreatment']
                n_treated = tbtx['n_treated'].values
                n_treatment_success = tbtx['n_treatment_success'].values
                n_treatment_failure = tbtx['n_treatment_failure'].values

                ax = axs[ax_idx][j]
                # Plot treatment outcomes
                ax.plot(time, n_treated, label='Treated', color='blue', marker='o', markersize=2)
                ax.plot(time, n_treatment_success, label='Successes', color='green', linestyle='--')
                ax.plot(time, n_treatment_failure, label='Failures', color='red', linestyle=':')

                ax.set_title(f'β={beta:.3f}, rel_sus={rel_sus:.2f}, mort={tb_mortality:.1e}')
                ax.grid(True)
                if ax_idx == nrows - 1:
                    ax.set_xlabel('Year')
                if j == 0:
                    ax.set_ylabel('Number of People')
                # Only the first panel carries a legend to keep the grid readable.
                if m == 0 and i == 0 and j == 0:
                    ax.legend(fontsize=6)

    plt.tight_layout()
    plt.suptitle('TB Treatment Outcomes', fontsize=14, y=1.02)

    # Set consistent x-axis ticks for all subplots (every 20 years).
    first_sim = sim_grid[0][0][0]
    time_years = np.array([d.year for d in first_sim.results['timevec']])
    tick_years = np.arange(time_years.min(), time_years.max() + 1, 20)
    tick_dates = [datetime.date(int(year), 1, 1) for year in tick_years]
    tick_labels = [str(year) for year in tick_years]
    for ax in axs.flat:
        ax.set_xticks(tick_dates)
        ax.set_xticklabels(tick_labels, rotation=45)

    filename = f"treatment_grid_{timestamp}.pdf"
    plt.savefig(get_output_path(filename), dpi=300, bbox_inches='tight')
    plt.show()
def plot_cumulative_treatment_grid(sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp):
    """
    Plot cumulative TB treatment outcomes for every parameter combination.

    Metrics plotted (read from ``sim.results['tbtreatment']``):
        - cum_treatment_success: Cumulative number of successful treatments over time (count)
        - cum_treatment_failure: Cumulative number of failed treatments over time (count)

    Rows enumerate the (tb_mortality, rel_sus) pairs and columns enumerate
    beta. X-axis ticks are placed every 20 years over the year range of the
    first simulation. The figure is saved as
    ``cumulative_treatment_grid_<timestamp>.pdf`` at the location returned by
    ``get_output_path`` and then shown.

    Args:
        sim_grid: Nested sequence of simulations indexed as
            ``sim_grid[m][i][j]`` for ``tb_mortality_vals[m]``,
            ``rel_sus_vals[i]`` and ``beta_vals[j]``.
        beta_vals: Swept beta values (one subplot column each).
        rel_sus_vals: Swept rel_sus values.
        tb_mortality_vals: Swept TB mortality values.
        timestamp: String embedded in the output file name.
    """
    nrows = len(tb_mortality_vals) * len(rel_sus_vals)
    ncols = len(beta_vals)
    # squeeze=False keeps `axs` two-dimensional even when the grid has a
    # single row and/or column, so axs[row][col] is valid for any sweep shape.
    fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows),
                            sharex=True, sharey=True, squeeze=False)

    for m, tb_mortality in enumerate(tb_mortality_vals):
        for i, rel_sus in enumerate(rel_sus_vals):
            ax_idx = m * len(rel_sus_vals) + i
            for j, beta in enumerate(beta_vals):
                sim = sim_grid[m][i][j]
                time = sim.results['timevec']

                # Get cumulative treatment results
                tbtx = sim.results['tbtreatment']
                cum_treatment_success = tbtx['cum_treatment_success'].values
                cum_treatment_failure = tbtx['cum_treatment_failure'].values

                ax = axs[ax_idx][j]
                # Plot cumulative treatment outcomes
                ax.plot(time, cum_treatment_success, label='Cumulative Successes', color='green', linestyle='--')
                ax.plot(time, cum_treatment_failure, label='Cumulative Failures', color='red', linestyle=':')

                ax.set_title(f'β={beta:.3f}, rel_sus={rel_sus:.2f}, mort={tb_mortality:.1e}')
                ax.grid(True)
                if ax_idx == nrows - 1:
                    ax.set_xlabel('Year')
                if j == 0:
                    ax.set_ylabel('Cumulative Treatments')
                # Only the first panel carries a legend to keep the grid readable.
                if m == 0 and i == 0 and j == 0:
                    ax.legend(fontsize=6)

    plt.tight_layout()
    plt.suptitle('Cumulative TB Treatment Outcomes', fontsize=14, y=1.02)

    # Set consistent x-axis ticks for all subplots (every 20 years).
    first_sim = sim_grid[0][0][0]
    time_years = np.array([d.year for d in first_sim.results['timevec']])
    tick_years = np.arange(time_years.min(), time_years.max() + 1, 20)
    tick_dates = [datetime.date(int(year), 1, 1) for year in tick_years]
    tick_labels = [str(year) for year in tick_years]
    for ax in axs.flat:
        ax.set_xticks(tick_dates)
        ax.set_xticklabels(tick_labels, rotation=45)

    filename = f"cumulative_treatment_grid_{timestamp}.pdf"
    plt.savefig(get_output_path(filename), dpi=300, bbox_inches='tight')
    plt.show()
def plot_age_prevalence_grid(sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp):
    """
    Plot age-stratified TB prevalence for every parameter combination.

    Each panel is a grouped bar chart of TB prevalence per 100,000 population
    by age group, including children (0-4, 5-14) and adults (15+). Model
    values come from ``compute_age_stratified_prevalence(sim, target_year=2018)``
    and are compared with the 2018 South Africa prevalence survey values,
    which are only available for the 15+ groups. For groups with survey data
    the relative difference ``(model - data) / data`` is annotated in percent.

    Rows enumerate the (tb_mortality, rel_sus) pairs and columns enumerate
    beta. The figure is saved as ``age_prevalence_grid_<timestamp>.pdf`` at
    the location returned by ``get_output_path`` and then shown.

    Args:
        sim_grid: Nested sequence of simulations indexed as
            ``sim_grid[m][i][j]`` for ``tb_mortality_vals[m]``,
            ``rel_sus_vals[i]`` and ``beta_vals[j]``.
        beta_vals: Swept beta values (one subplot column each).
        rel_sus_vals: Swept rel_sus values.
        tb_mortality_vals: Swept TB mortality values.
        timestamp: String embedded in the output file name.
    """
    # 2018 South Africa survey data (per 100,000 population) - only available for 15+
    sa_2018_data = {
        '15-24': 432,
        '25-34': 902,
        '35-44': 1107,
        '45-54': 1063,
        '55-64': 845,
        '65+': 1104,
    }

    # All age groups including children
    all_age_groups = ['0-4', '5-14', '15-24', '25-34', '35-44', '45-54', '55-64', '65+']

    # Survey values aligned with all_age_groups; NaN where no data exists (children).
    sa_2018_values = np.array([sa_2018_data.get(group, np.nan) for group in all_age_groups],
                              dtype=float)
    valid_data_mask = ~np.isnan(sa_2018_values)

    # Bar geometry is identical for every panel, so compute it once.
    x_pos = np.arange(len(all_age_groups))
    width = 0.35

    nrows = len(tb_mortality_vals) * len(rel_sus_vals)
    ncols = len(beta_vals)
    # squeeze=False keeps `axs` two-dimensional even when the grid has a
    # single row and/or column, so axs[row][col] is valid for any sweep shape.
    fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows),
                            sharex=True, sharey=True, squeeze=False)

    for m, tb_mortality in enumerate(tb_mortality_vals):
        for i, rel_sus in enumerate(rel_sus_vals):
            ax_idx = m * len(rel_sus_vals) + i
            for j, beta in enumerate(beta_vals):
                sim = sim_grid[m][i][j]

                # Compute age-stratified prevalence for 2018
                age_prevalence = compute_age_stratified_prevalence(sim, target_year=2018)
                model_prevalence = [age_prevalence[group]['prevalence_per_100k']
                                    for group in all_age_groups]

                ax = axs[ax_idx][j]

                # Model results
                bars1 = ax.bar(x_pos - width/2, model_prevalence, width,
                               label='Model (2018)', alpha=0.8, color='blue')

                # South Africa 2018 data (only for age groups with data)
                bars2 = ax.bar(x_pos[valid_data_mask] + width/2, sa_2018_values[valid_data_mask],
                               width, label='SA Data (2018)', alpha=0.8, color='red')

                # Add value labels on bars
                for bar in list(bars1) + list(bars2):
                    height = bar.get_height()
                    if height > 0:
                        ax.text(bar.get_x() + bar.get_width()/2., height + 50,
                                f'{height:.0f}', ha='center', va='bottom', fontsize=8)

                ax.set_title(f'β={beta:.3f}, rel_sus={rel_sus:.2f}, mort={tb_mortality:.1e}')
                ax.set_xlabel('Age Group')
                ax.set_ylabel('TB Prevalence (per 100,000)')
                ax.set_xticks(x_pos)
                ax.set_xticklabels(all_age_groups, rotation=45)
                ax.grid(True, alpha=0.3)

                # Only the first panel carries a legend to keep the grid readable.
                if m == 0 and i == 0 and j == 0:
                    ax.legend(fontsize=6)

                # Add percentage differences (only for age groups with survey data)
                for k, (model_val, data_val) in enumerate(zip(model_prevalence, sa_2018_values)):
                    if not np.isnan(data_val) and data_val > 0:
                        pct_diff = ((model_val - data_val) / data_val) * 100
                        ax.annotate(f'{pct_diff:.1f}%',
                                    xy=(k, max(model_val, data_val) + 100),
                                    xytext=(0, 5), textcoords='offset points',
                                    ha='center', fontsize=7, color='darkgreen')

    plt.tight_layout()
    plt.suptitle('Age-Stratified TB Prevalence: Model vs South Africa 2018 Survey Data',
                 fontsize=14, y=1.02)
    filename = f"age_prevalence_grid_{timestamp}.pdf"
    plt.savefig(get_output_path(filename), dpi=300, bbox_inches='tight')
    plt.show()
def compute_hiv_tb_coinfection_rates(sim, target_year=2018):
    """
    Compute HIV coinfection rates among adult TB cases by symptom status.

    NOTE: the rates are derived from the *current* agent state held in
    ``sim.people`` / ``sim.diseases`` (i.e. whatever state the simulation is
    in when this function is called). ``target_year`` is not used to select a
    historical snapshot; a warning is issued when the last year of
    ``sim.results['timevec']`` differs from ``target_year`` so that results
    are not mistaken for a ``target_year`` estimate.

    Categories (restricted to alive agents aged 15+):
        - ``presymptomatic``: TB state ``ACTIVE_PRESYMP``
        - ``symptomatic``: TB state ``ACTIVE_SMPOS``, ``ACTIVE_SMNEG`` or ``ACTIVE_EXPTB``
        - ``all_active``: any of the four active states above

    Args:
        sim: Simulation object exposing ``people``, ``diseases.tb.state``,
            ``diseases.hiv.state`` and ``results['timevec']``.
        target_year: Year the caller intends the rates to represent.

    Returns:
        dict: Maps each category name to a dict with ``total_cases``,
        ``hiv_positive`` and ``hiv_rate_percent`` (0 when there are no cases).
    """
    import warnings

    time_years = np.array([d.year for d in sim.results['timevec']])
    if time_years.size and time_years[-1] != target_year:
        warnings.warn(
            f"compute_hiv_tb_coinfection_rates uses the current state of sim.people "
            f"(last simulated year {time_years[-1]}); target_year={target_year} is not applied",
            stacklevel=2,
        )

    people = sim.people
    tb_states = sim.diseases.tb.state
    hiv_states = sim.diseases.hiv.state

    # Get HIV-positive states (assuming HIV states 1, 2, 3 are positive - adjust as needed)
    # HIV states: 0=uninfected, 1=acute, 2=latent, 3=AIDS
    hiv_positive_mask = np.isin(hiv_states, [1, 2, 3])

    # Restrict to alive adults (age 15+)
    alive_adult_mask = people.alive & (people.age >= 15)

    # TB case definitions by symptom status
    symptomatic_states = [mtb.TBS.ACTIVE_SMPOS, mtb.TBS.ACTIVE_SMNEG, mtb.TBS.ACTIVE_EXPTB]
    tb_masks = {
        # Presymptomatic (0 symptoms)
        'presymptomatic': (tb_states == mtb.TBS.ACTIVE_PRESYMP),
        # Symptomatic (>=1 symptoms)
        'symptomatic': np.isin(tb_states, symptomatic_states),
        # All active TB (any symptoms)
        'all_active': np.isin(tb_states, [mtb.TBS.ACTIVE_PRESYMP] + symptomatic_states),
    }

    coinfection_rates = {}
    for category, tb_mask in tb_masks.items():
        case_mask = alive_adult_mask & tb_mask
        total_cases = np.sum(case_mask)
        hiv_positive = np.sum(case_mask & hiv_positive_mask)
        # Guard against division by zero when a category has no cases.
        hiv_rate = (hiv_positive / total_cases * 100) if total_cases > 0 else 0
        coinfection_rates[category] = {
            'total_cases': total_cases,
            'hiv_positive': hiv_positive,
            'hiv_rate_percent': hiv_rate,
        }

    return coinfection_rates
def plot_hiv_tb_coinfection_grid(sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp):
    """
    Plot HIV coinfection rates among TB cases by symptom status for every
    parameter combination.

    Each panel is a grouped bar chart comparing the rates returned by
    ``compute_hiv_tb_coinfection_rates(sim, target_year=2018)`` with the 2018
    South Africa survey values for three categories (presymptomatic,
    symptomatic, all active). Bars are labelled with their values, the
    relative difference ``(model - data) / data`` is annotated in percent and
    the model case counts are printed below each category.

    Rows enumerate the (tb_mortality, rel_sus) pairs and columns enumerate
    beta. The figure is saved as ``hiv_tb_coinfection_grid_<timestamp>.pdf``
    at the location returned by ``get_output_path`` and then shown.

    Args:
        sim_grid: Nested sequence of simulations indexed as
            ``sim_grid[m][i][j]`` for ``tb_mortality_vals[m]``,
            ``rel_sus_vals[i]`` and ``beta_vals[j]``.
        beta_vals: Swept beta values (one subplot column each).
        rel_sus_vals: Swept rel_sus values.
        tb_mortality_vals: Swept TB mortality values.
        timestamp: String embedded in the output file name.
    """
    import matplotlib.ticker as mtick

    # 2018 South Africa survey data (HIV coinfection rates by symptom status)
    sa_2018_data = {
        'presymptomatic': 22.4,  # 0 symptoms (presymptomatic)
        'symptomatic': 36.9,     # ≥1 symptoms (symptomatic) - calculated from weighted average
        'all_active': 28.8,      # All active TB cases
    }

    categories = ['presymptomatic', 'symptomatic', 'all_active']
    category_labels = ['0 Symptoms\n(Presymptomatic)', '≥1 Symptoms\n(Symptomatic)', 'All Active TB']
    sa_2018_values = [sa_2018_data[cat] for cat in categories]

    # Bar geometry is identical for every panel, so compute it once.
    x_pos = np.arange(len(categories))
    width = 0.35

    nrows = len(tb_mortality_vals) * len(rel_sus_vals)
    ncols = len(beta_vals)
    # squeeze=False keeps `axs` two-dimensional even when the grid has a
    # single row and/or column, so axs[row][col] is valid for any sweep shape.
    fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows),
                            sharex=True, sharey=True, squeeze=False)

    for m, tb_mortality in enumerate(tb_mortality_vals):
        for i, rel_sus in enumerate(rel_sus_vals):
            ax_idx = m * len(rel_sus_vals) + i
            for j, beta in enumerate(beta_vals):
                sim = sim_grid[m][i][j]

                # Compute HIV-TB coinfection rates for 2018
                coinfection_rates = compute_hiv_tb_coinfection_rates(sim, target_year=2018)
                model_rates = [coinfection_rates[cat]['hiv_rate_percent'] for cat in categories]

                ax = axs[ax_idx][j]

                # Model results
                bars1 = ax.bar(x_pos - width/2, model_rates, width,
                               label='Model (2018)', alpha=0.8, color='blue')
                # South Africa 2018 data
                bars2 = ax.bar(x_pos + width/2, sa_2018_values, width,
                               label='SA Data (2018)', alpha=0.8, color='red')

                # Add value labels on bars
                for bar in list(bars1) + list(bars2):
                    height = bar.get_height()
                    if height > 0:
                        ax.text(bar.get_x() + bar.get_width()/2., height + 1,
                                f'{height:.1f}%', ha='center', va='bottom', fontsize=8)

                ax.set_title(f'β={beta:.3f}, rel_sus={rel_sus:.2f}, mort={tb_mortality:.1e}')
                ax.set_xlabel('TB Symptom Status')
                ax.set_ylabel('HIV Coinfection Rate (%)')
                ax.set_xticks(x_pos)
                ax.set_xticklabels(category_labels, rotation=0, ha='center')
                ax.grid(True, alpha=0.3)

                # Set y-axis to show percentages properly
                ax.yaxis.set_major_formatter(mtick.PercentFormatter())

                # Only the first panel carries a legend to keep the grid readable.
                if m == 0 and i == 0 and j == 0:
                    ax.legend(fontsize=6)

                # Add percentage differences
                for k, (model_val, data_val) in enumerate(zip(model_rates, sa_2018_values)):
                    if data_val > 0:
                        pct_diff = ((model_val - data_val) / data_val) * 100
                        ax.annotate(f'{pct_diff:.1f}%',
                                    xy=(k, max(model_val, data_val) + 2),
                                    xytext=(0, 5), textcoords='offset points',
                                    ha='center', fontsize=7, color='darkgreen')

                # Add case counts as text annotations
                for k, cat in enumerate(categories):
                    total_cases = coinfection_rates[cat]['total_cases']
                    hiv_positive = coinfection_rates[cat]['hiv_positive']
                    ax.text(k, -5, f'n={total_cases}\nHIV+={hiv_positive}',
                            ha='center', va='top', fontsize=6, color='gray')

    plt.tight_layout()
    plt.suptitle('HIV Coinfection Rates Among TB Cases by Symptom Status: Model vs South Africa 2018 Survey Data',
                 fontsize=14, y=1.02)
    filename = f"hiv_tb_coinfection_grid_{timestamp}.pdf"
    plt.savefig(get_output_path(filename), dpi=300, bbox_inches='tight')
    plt.show()
def plot_case_notification_rate_grid(sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp):
    """
    Plot annualized TB case notification and incidence rates (per 100,000)
    for every parameter combination in a grid.

    The notification rate at time t is the difference in cumulative positive
    diagnoses (``tbdiagnostic['cum_test_positive']``) between t and t-365
    days, divided by the population at t, times 100,000. The incidence rate
    is computed the same way from ``tb['cum_active']`` (or the cumulative sum
    of ``tb['new_active']`` when ``cum_active`` is absent). For time points
    less than 365 days after the start, the first time step is used as the
    reference. South Africa notification data loaded from the GTB report
    ``.rda`` files are overlaid on every panel.

    Rows enumerate the (tb_mortality, rel_sus) pairs and columns enumerate
    beta. The figure is saved as ``case_notification_rate_grid_<timestamp>.pdf``
    at the location returned by ``get_output_path`` and then shown.

    Args:
        sim_grid: Nested sequence of simulations indexed as
            ``sim_grid[m][i][j]`` for ``tb_mortality_vals[m]``,
            ``rel_sus_vals[i]`` and ``beta_vals[j]``.
        beta_vals: Swept beta values (one subplot column each).
        rel_sus_vals: Swept rel_sus values.
        tb_mortality_vals: Swept TB mortality values.
        timestamp: String embedded in the output file name.

    Raises:
        ValueError: If an ``.rda`` file contains no DataFrame, if no
            notification or population column can be identified, or if the TB
            results contain neither ``cum_active`` nor ``new_active``.
    """
    import matplotlib.ticker as mtick
    import os
    import rdata
    import pandas as pd

    # --- Load real notification data (from extract_gtb_data.py logic) ---
    base_dir = os.path.dirname(os.path.abspath(__file__))
    gtb_dir = os.path.join(base_dir, '../tbsim/data/gtbreport2024/data/gtb')
    snapshot_dir = os.path.join(gtb_dir, 'snapshot_2024-07-29')
    other_dir = os.path.join(gtb_dir, 'other')
    tb_rda_path = os.path.join(snapshot_dir, 'tb.rda')
    pop_rda_path = os.path.join(other_dir, 'pop.rda')

    def load_rda_df(rda_path):
        # Return the first pandas DataFrame found in an RDA file.
        parsed = rdata.parser.parse_file(rda_path)
        converted = rdata.conversion.convert(parsed)
        for v in converted.values():
            if isinstance(v, pd.DataFrame):
                return v
        raise ValueError(f"No DataFrame found in {rda_path}")

    def annualized_per_100k(time, cumulative, population):
        # Trailing-365-day difference of a cumulative series, per 100,000 population.
        rate = np.zeros_like(cumulative, dtype=float)
        one_year = datetime.timedelta(days=365)
        for t, t_date in enumerate(time):
            t_prev_date = t_date - one_year
            # Latest step at or before t_prev_date (clamped to the first step).
            t_prev = np.searchsorted(time, t_prev_date)
            if t_prev == len(time) or time[t_prev] > t_prev_date:
                t_prev = max(0, t_prev - 1)
            pop = population[t]
            if pop > 0:
                rate[t] = ((cumulative[t] - cumulative[t_prev]) / pop) * 1e5
        return rate

    tb_df = load_rda_df(tb_rda_path)
    pop_df = load_rda_df(pop_rda_path)
    sa_code = 'ZAF'
    tb_sa = tb_df[tb_df['iso3'] == sa_code]
    pop_sa = pop_df[pop_df['iso3'] == sa_code]

    # Prefer a known notification column; otherwise fall back to the first
    # column whose name looks like a notification count.
    notif_var = next((v for v in ['new_bact_pos', 'new_labconf', 'new_notif', 'new_pos']
                      if v in tb_sa.columns), None)
    if notif_var is None:
        notif_vars = [col for col in tb_sa.columns
                      if 'new' in col and ('bact' in col or 'labconf' in col
                                           or 'notif' in col or 'pos' in col)]
        if notif_vars:
            notif_var = notif_vars[0]
    if notif_var is None:
        raise ValueError('No notification variable found in TB data')

    pop_col = next((c for c in ['pop', 'e_pop_num', 'population'] if c in pop_sa.columns), None)
    if pop_col is None:
        raise ValueError('No population column found in population data')

    merged = pd.merge(tb_sa[['year', notif_var]], pop_sa[['year', pop_col]], on='year', how='inner')
    merged = merged.sort_values('year')
    merged['notif_rate_per_100k'] = merged[notif_var] / merged[pop_col] * 1e5
    real_years = merged['year'].values
    real_rates = merged['notif_rate_per_100k'].values
    real_dates = [datetime.date(int(y), 1, 1) for y in real_years]

    # --- Plot model grid ---
    nrows = len(tb_mortality_vals) * len(rel_sus_vals)
    ncols = len(beta_vals)
    # squeeze=False keeps `axs` two-dimensional even when the grid has a
    # single row and/or column, so axs[row][col] is valid for any sweep shape.
    fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows),
                            sharex=True, sharey=True, squeeze=False)

    for m, tb_mortality in enumerate(tb_mortality_vals):
        for i, rel_sus in enumerate(rel_sus_vals):
            ax_idx = m * len(rel_sus_vals) + i
            for j, beta in enumerate(beta_vals):
                sim = sim_grid[m][i][j]
                time = np.array(sim.results['timevec'])
                n_alive = sim.results['n_alive']

                # Annualized notification rate
                cum_test_positive = sim.results['tbdiagnostic']['cum_test_positive'].values
                notification_rate = annualized_per_100k(time, cum_test_positive, n_alive)

                # Annualized TB incidence rate
                tb_results = sim.results['tb']
                if 'cum_active' in tb_results:
                    cum_incidence = tb_results['cum_active']
                elif 'new_active' in tb_results:
                    # Fallback: compute cumulative sum of new_active
                    cum_incidence = np.cumsum(tb_results['new_active'])
                else:
                    raise ValueError('No new_active or cum_active in tb results')
                incidence_rate = annualized_per_100k(time, cum_incidence, n_alive)

                ax = axs[ax_idx][j]
                ax.plot(time, notification_rate, color='purple', label='Model Notification Rate')
                ax.plot(time, incidence_rate, color='blue', label='Model Incidence Rate')
                # Overlay real data
                ax.plot(real_dates, real_rates, marker='o', color='red', label='SA Notification Data')

                ax.set_title(f'β={beta:.3f}, rel_sus={rel_sus:.2f}, mort={tb_mortality:.1e}')
                ax.grid(True)
                if ax_idx == nrows - 1:
                    ax.set_xlabel('Year')
                if j == 0:
                    ax.set_ylabel('Rate (per 100,000)')
                # Only the first panel carries a legend to keep the grid readable.
                if m == 0 and i == 0 and j == 0:
                    ax.legend(fontsize=7)
                ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.0f'))

    plt.tight_layout()
    plt.suptitle('Annualized TB Case Notification Rate (per 100,000)', fontsize=14, y=1.02)

    # Set consistent x-axis ticks for all subplots (every 20 years).
    first_sim = sim_grid[0][0][0]
    time_years = np.array([d.year for d in first_sim.results['timevec']])
    tick_years = np.arange(time_years.min(), time_years.max() + 1, 20)
    tick_dates = [datetime.date(int(year), 1, 1) for year in tick_years]
    tick_labels = [str(year) for year in tick_years]
    for ax in axs.flat:
        ax.set_xticks(tick_dates)
        ax.set_xticklabels(tick_labels, rotation=45)

    filename = f"case_notification_rate_grid_{timestamp}.pdf"
    plt.savefig(get_output_path(filename), dpi=300, bbox_inches='tight')
    plt.show()
def compute_annualized_tb_mortality_rate(sim):
    """
    Compute the annualized TB mortality rate (per 100,000 population) over time.

    For every time step T the rate is:

    1. the difference in cumulative TB deaths between T and the latest step
       at or before T-365 days (the first step when no such step exists),
    2. divided by the population at T,
    3. multiplied by 100,000.

    Cumulative deaths are read from ``tb['cum_deaths']`` or, if absent,
    computed as the cumulative sum of ``tb['new_deaths']``. Population comes
    from ``sim.results['n_alive']``; if that key is missing, the current
    number of alive agents is used for every step. Steps with zero population
    get a rate of 0.

    Args:
        sim: Simulation exposing ``results['timevec']``, ``results['tb']`` and
            either ``results['n_alive']`` or ``people.alive``.

    Returns:
        numpy.ndarray: Annualized TB mortality rate per 100,000, one value
        per time step.

    Raises:
        ValueError: If the TB results contain neither ``cum_deaths`` nor
            ``new_deaths``.
    """
    def unwrap(series):
        # Accept both plain arrays and objects exposing `.values` (e.g. pandas Series).
        return getattr(series, 'values', series)

    # Convert once so np.searchsorted does not re-convert the time vector on
    # every loop iteration.
    time = np.asarray(sim.results['timevec'])
    tb_results = sim.results['tb']

    # Get population size over time
    try:
        n_alive = unwrap(sim.results['n_alive'])
    except KeyError:
        n_alive = np.full(len(time), fill_value=np.count_nonzero(sim.people.alive))

    # Get cumulative TB deaths
    if 'cum_deaths' in tb_results:
        cum_deaths = unwrap(tb_results['cum_deaths'])
    elif 'new_deaths' in tb_results:
        # Fallback: compute cumulative sum of new_deaths
        cum_deaths = np.cumsum(unwrap(tb_results['new_deaths']))
    else:
        raise ValueError('No new_deaths or cum_deaths in tb results')

    # Compute annualized mortality rate
    one_year = datetime.timedelta(days=365)
    mortality_rate = np.zeros_like(cum_deaths, dtype=float)
    for t, t_date in enumerate(time):
        t_prev_date = t_date - one_year
        # Latest step at or before t_prev_date (clamped to the first step).
        t_prev = np.searchsorted(time, t_prev_date)
        if t_prev == len(time) or time[t_prev] > t_prev_date:
            t_prev = max(0, t_prev - 1)

        pop = n_alive[t]
        if pop > 0:
            deaths_diff = cum_deaths[t] - cum_deaths[t_prev]
            mortality_rate[t] = (deaths_diff / pop) * 1e5

    return mortality_rate
def plot_tb_mortality_rate_grid(sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp):
    """
    Plot the annualized TB mortality rate (per 100,000) for every parameter combination.

    One subplot is drawn per simulation: rows enumerate the
    (tb_mortality, rel_sus) pairs and columns enumerate the beta values. The
    curve in each panel comes from ``compute_annualized_tb_mortality_rate``.
    The figure is saved as ``tb_mortality_rate_grid_{timestamp}.pdf`` at the
    location returned by ``get_output_path`` and then shown.

    Args:
        sim_grid: Nested sequence indexed as ``sim_grid[m][i][j]`` for
            tb_mortality index ``m``, rel_sus index ``i`` and beta index ``j``.
        beta_vals: Beta values (grid columns).
        rel_sus_vals: Relative susceptibility values.
        tb_mortality_vals: TB mortality values.
        timestamp: String embedded in the output filename.
    """
    import matplotlib.ticker as mtick

    nrows = len(tb_mortality_vals) * len(rel_sus_vals)
    ncols = len(beta_vals)
    # squeeze=False keeps `axs` two-dimensional even for a single row and/or a
    # single column, so axs[row, col] is always valid (the squeezed default
    # broke indexing whenever only one beta value was swept).
    fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows),
                            sharex=True, sharey=True, squeeze=False)

    for m, tb_mortality in enumerate(tb_mortality_vals):
        for i, rel_sus in enumerate(rel_sus_vals):
            ax_idx = m * len(rel_sus_vals) + i
            for j, beta in enumerate(beta_vals):
                sim = sim_grid[m][i][j]
                # Named `timevec` so the `time` module is not shadowed
                timevec = np.array(sim.results['timevec'])
                mortality_rate = compute_annualized_tb_mortality_rate(sim)

                ax = axs[ax_idx, j]
                ax.plot(timevec, mortality_rate, color='red', label='Annual TB Mortality Rate', linewidth=2)
                ax.set_title(f'β={beta:.3f}, rel_sus={rel_sus:.2f}, mort={tb_mortality:.1e}')
                ax.grid(True)
                if ax_idx == nrows - 1:
                    ax.set_xlabel('Year')
                if j == 0:
                    ax.set_ylabel('TB Mortality Rate (per 100,000)')
                if m == 0 and i == 0 and j == 0:
                    ax.legend(fontsize=7)
                ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.0f'))

    plt.tight_layout()
    plt.suptitle('Annualized TB Mortality Rate (per 100,000)', fontsize=14, y=1.02)

    # Same x ticks on every subplot: every 20 years across the first sim's time span
    first_sim = sim_grid[0][0][0]
    time_years = np.array([d.year for d in first_sim.results['timevec']])
    xticks = [int(year) for year in np.arange(time_years.min(), time_years.max() + 1, 20)]
    tick_dates = [datetime.date(year, 1, 1) for year in xticks]
    tick_labels = [str(year) for year in xticks]
    for ax in axs.flat:
        ax.set_xticks(tick_dates)
        ax.set_xticklabels(tick_labels, rotation=45)

    filename = f"tb_mortality_rate_grid_{timestamp}.pdf"
    plt.savefig(get_output_path(filename), dpi=300, bbox_inches='tight')
    plt.show()
def compute_age_distribution_at_year(sim, target_year=2022):
    """
    Compute the age distribution of living agents in 5-year bins.

    Bins are half-open: ``[0, 5)``, ``[5, 10)``, ..., ``[90, 95)`` and an
    open-ended ``95+`` bin. This way non-integer ages (e.g. 4.5) are counted
    in exactly one bin; the earlier inclusive integer bounds (0-4, 5-9, ...)
    silently dropped every age between two bins.

    Args:
        sim: Object whose ``people.alive`` mask and ``people.age`` array are read.
        target_year: Kept for interface compatibility. It is NOT used to pick
            a snapshot: the distribution is always taken from ``sim.people``
            as it currently stands.

    Returns:
        dict: Maps each bin label ('0-4', '5-9', ..., '95+') to
        ``{'count': int, 'percentage': float}``, where percentage is relative
        to all living agents (0 when nobody is alive).
    """
    people = sim.people
    alive_mask = people.alive
    ages = people.age[alive_mask]

    bin_width = 5
    last_lower_edge = 95
    lower_edges = list(range(0, last_lower_edge + 1, bin_width))  # 0, 5, ..., 95

    total_population = len(ages)
    age_distribution = {}
    for lower in lower_edges:
        if lower == last_lower_edge:
            label = f'{lower}+'
            in_bin = ages >= lower
        else:
            label = f'{lower}-{lower + bin_width - 1}'
            in_bin = (ages >= lower) & (ages < lower + bin_width)

        count_in_bin = int(np.sum(in_bin))
        percentage = (count_in_bin / total_population) * 100 if total_population > 0 else 0
        age_distribution[label] = {
            'count': count_in_bin,
            'percentage': percentage,
        }

    return age_distribution
def plot_population_pyramid_grid(sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp, target_year=2022):
    """
    Plot one horizontal age-distribution bar chart per parameter combination.

    Rows enumerate the (tb_mortality, rel_sus) pairs and columns enumerate the
    beta values. The percentages in each panel come from
    ``compute_age_distribution_at_year(sim, target_year)``. The figure is
    saved as ``population_pyramid_grid_{timestamp}.pdf`` at the location
    returned by ``get_output_path`` and then shown.

    Args:
        sim_grid: Nested sequence indexed as ``sim_grid[m][i][j]`` for
            tb_mortality index ``m``, rel_sus index ``i`` and beta index ``j``.
        beta_vals: Beta values (grid columns).
        rel_sus_vals: Relative susceptibility values.
        tb_mortality_vals: TB mortality values.
        timestamp: String embedded in the output filename.
        target_year: Forwarded to ``compute_age_distribution_at_year`` and
            shown in the figure title (default: 2022).
    """
    nrows = len(tb_mortality_vals) * len(rel_sus_vals)
    ncols = len(beta_vals)
    # squeeze=False keeps `axs` two-dimensional even for a single row and/or a
    # single column, so axs[row, col] is always valid (the squeezed default
    # broke indexing whenever only one beta value was swept).
    fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3 * nrows),
                            sharex=True, sharey=True, squeeze=False)

    age_bin_labels = ['0-4', '5-9', '10-14', '15-19', '20-24', '25-29', '30-34', '35-39',
                      '40-44', '45-49', '50-54', '55-59', '60-64', '65-69', '70-74',
                      '75-79', '80-84', '85-89', '90-94', '95+']
    y_pos = np.arange(len(age_bin_labels))
    widest_bar = 0.0  # largest percentage over all panels, used for the shared x limit

    for m, tb_mortality in enumerate(tb_mortality_vals):
        for i, rel_sus in enumerate(rel_sus_vals):
            ax_idx = m * len(rel_sus_vals) + i
            for j, beta in enumerate(beta_vals):
                sim = sim_grid[m][i][j]

                age_dist = compute_age_distribution_at_year(sim, target_year)
                percentages = [age_dist[label]['percentage'] for label in age_bin_labels]
                widest_bar = max(widest_bar, max(percentages))

                ax = axs[ax_idx, j]
                bars = ax.barh(y_pos, percentages, color='skyblue', alpha=0.7)

                ax.set_yticks(y_pos)
                ax.set_yticklabels(age_bin_labels)
                ax.set_xlabel('Percentage of Population')
                ax.set_title(f'β={beta:.3f}, rel_sus={rel_sus:.2f}, mort={tb_mortality:.1e}')
                ax.grid(True, alpha=0.3)

                # Annotate only bars large enough for the label to be meaningful
                for bar, percentage in zip(bars, percentages):
                    if percentage > 0.5:
                        ax.text(bar.get_width() + 0.1, bar.get_y() + bar.get_height()/2,
                                f'{percentage:.1f}%', va='center', ha='left', fontsize=8)

                if j == 0:
                    ax.set_ylabel('Age Group')

    # The x axis is shared, so a per-panel set_xlim would let the last panel
    # drawn dictate the limit and could clip earlier panels. Set it once from
    # the global maximum instead; the shared axes propagate it to every panel.
    if widest_bar > 0:
        axs[0, 0].set_xlim(0, widest_bar * 1.2)

    plt.tight_layout()
    plt.suptitle(f'Population Age Distribution at {target_year}', fontsize=14, y=1.02)

    filename = f"population_pyramid_grid_{timestamp}.pdf"
    plt.savefig(get_output_path(filename), dpi=300, bbox_inches='tight')
    plt.show()
def _find_first_existing(candidates, filename):
    """Return the first path in `candidates` that exists; raise FileNotFoundError if none does."""
    found = next((path for path in candidates if os.path.exists(path)), None)
    if found is None:
        raise FileNotFoundError(f"Could not find {filename} in any of the expected locations: {candidates}")
    return found


def run_sim(beta, rel_sus_latentslow, tb_mortality, seed=0, years=200, n_agents=1000):  # 8000
    """
    Build and run one TB-HIV simulation with care-seeking, diagnostic and treatment interventions.

    Args:
        beta: TB transmission parameter, wrapped in ``ss.probpermonth``.
        rel_sus_latentslow: Passed through as the TB ``rel_sus_latentslow`` parameter.
        tb_mortality: Per-day death rate for smear-positive TB; the
            extrapulmonary and smear-negative rates are 0.15x and 0.3x this value.
        seed: Random seed (``rand_seed``) of the simulation.
        years: Number of years simulated after the 1850 start year.
        n_agents: Number of agents passed to ``make_people``.

    Returns:
        The ``ss.Sim`` object after ``run()`` has been called.

    Raises:
        FileNotFoundError: If the South Africa CBR or ASMR CSV file is not
            found at any of the candidate relative paths.
    """
    start_year = 1850  # 1750
    sim_pars = dict(
        dt=ss.days(30),
        start=ss.date(f'{start_year}-01-01'),
        stop=ss.date(f'{start_year + years}-01-01'),
        rand_seed=seed,
        verbose=0,
    )

    # To do: Add time-varying birth rate and age-, sex-, year-specific mortality

    # Candidate locations for the demographic data files. The paths are
    # relative, so which one matches depends on the current working directory.
    possible_cbr_paths = [
        '../tbsim/data/South_Africa_CBR.csv',  # Added correct relative path
        '../data/South_Africa_CBR.csv',
        'tbsim/data/South_Africa_CBR.csv',
        'data/South_Africa_CBR.csv',
    ]
    possible_asmr_paths = [
        '../tbsim/data/South_Africa_ASMR.csv',  # Added correct relative path
        '../data/South_Africa_ASMR.csv',
        'tbsim/data/South_Africa_ASMR.csv',
        'data/South_Africa_ASMR.csv',
    ]
    cbr_path = _find_first_existing(possible_cbr_paths, 'South_Africa_CBR.csv')
    asmr_path = _find_first_existing(possible_asmr_paths, 'South_Africa_ASMR.csv')

    cbr = pd.read_csv(cbr_path)    # Crude birth rate per 1000
    asmr = pd.read_csv(asmr_path)  # Age-specific mortality rate
    demog = [
        ss.Births(birth_rate=cbr, dt=ss.days(30)),
        ss.Deaths(death_rate=asmr, dt=ss.days(30), rate_units=1),  # rate_units=1 = per person-year
    ]
    people = make_people(n_agents=n_agents)

    tb_pars = dict(
        beta=ss.probpermonth(beta),  # ss.prob(beta),
        init_prev=ss.bernoulli(p=0.10),  # Higher initial prevalence for South Africa context
        rel_sus_latentslow=rel_sus_latentslow,
        p_latent_fast=ss.bernoulli(p=0.1),  # Base fast progressor fraction (will be overridden by age-specific intervention)
        # South Africa-specific adjustments
        rate_LS_to_presym=ss.perday(5e-5),  # Slightly higher progression for HIV context
        rate_LF_to_presym=ss.perday(8e-3),  # Higher fast progression rate
        rate_active_to_clear=ss.perday(1.5e-4),  # Lower clearance rate (more persistent)
        rate_smpos_to_dead=ss.perday(tb_mortality),
        rate_exptb_to_dead=ss.perday(0.15 * tb_mortality),
        rate_smneg_to_dead=ss.perday(0.3 * tb_mortality),
    )
    tb = cf.make_tb_comorbidity(tb_pars=tb_pars)

    # Add HIV for South Africa context (critical for TB dynamics)
    hiv_pars = dict(
        init_prev=ss.bernoulli(p=0.00),  # Start with no HIV, will be added via intervention
        init_onart=ss.bernoulli(p=0.00),
    )
    hiv = cf.make_hiv_comorbidity(hiv_pars=hiv_pars)

    net = ss.RandomNet(pars=dict(n_contacts=ss.poisson(lam=5), dur=0))

    # Add TB-HIV connector to model coinfection effects with increased progression rates
    # Higher multipliers to get steeper TB prevalence increase from 1990 onwards
    # Increased by 50% from previous values
    tb_hiv_connector = cf.make_tb_hiv_connector(pars=dict(
        acute_multiplier=4.5,    # Increased from 3.0 to 4.5 (50% higher)
        latent_multiplier=7.5,   # Increased from 5.0 to 7.5 (50% higher)
        aids_multiplier=12.0,    # Increased from 8.0 to 12.0 (50% higher)
    ))

    # Add custom HIV intervention with gradual ramp-up based on van Schalkwyk et al. 2021 data for eThekwini
    hiv_intervention = GradualHIVIntervention(pars=dict(
        percent_on_ART=0.50,  # 50% of HIV-positive individuals on ART
        start=ss.date('1990-01-01'),  # Start from 1990 when HIV epidemic began
        stop=ss.date(f'{start_year + years}-01-01'),
    ))

    # Add health-seeking behavior intervention, active for the whole simulated period.
    # NOTE(review): the rate below is 1/120 per day, whereas the module header
    # and the previous comments here describe a 90-day average delay (1/90 per
    # day) — confirm which value is intended.
    health_seeking = HealthSeekingBehavior(pars=dict(
        initial_care_seeking_rate=ss.perday(1/120),
        start=ss.date(f'{start_year}-01-01'),
        stop=ss.date(f'{start_year + years}-01-01'),
        single_use=True,
    ))

    # Add TB diagnostic intervention.
    # NOTE(review): sensitivity is 0.50 here, whereas the module header and
    # the previous comments describe 60% sensitivity — confirm which value is
    # intended.
    tb_diagnostic = TBDiagnostic(pars=dict(
        coverage=ss.bernoulli(0.7, strict=False),  # 70% coverage - not everyone gets tested
        sensitivity=0.50,
        specificity=0.95,  # 95% specificity (standard)
        reset_flag=False,
        care_seeking_multiplier=1.0,  # 2.0 to encourage retries for false negatives
    ))

    # Add TB treatment intervention (70% success rate - less effective for better burn-in)
    tb_treatment = TBTreatment(pars=dict(
        treatment_success_rate=0.70,  # 70% treatment success rate - less effective treatment
        reseek_multiplier=1.0,  # 2.0 to encourage retries for treatment failures
        reset_flags=True,  # Reset diagnostic flags after treatment failure
    ))

    # Add age-dependent TB progression intervention
    age_tb_progression = AgeDependentTBProgression(pars=dict(
        age_0_4_multiplier=2.0,     # 2x progression for 0-4 year olds
        age_5_14_multiplier=0.5,    # 0.5x progression for 5-14 year olds
        age_15plus_multiplier=1.0,  # 1x progression for 15+ year olds
    ))

    # Combine all interventions
    all_interventions = [hiv_intervention, health_seeking, tb_diagnostic, tb_treatment, age_tb_progression]

    sim = ss.Sim(
        people=people,
        diseases=[tb, hiv],
        networks=net,
        demographics=demog,
        connectors=tb_hiv_connector,  # Pass connector directly, not in a list
        interventions=all_interventions,  # Combined interventions list
        pars=sim_pars,
    )
    sim.run()
    return sim
def refined_sweep(beta_vals, rel_sus_vals, tb_mortality_vals):
    """
    Run one simulation per parameter combination and produce all sweep plots.

    Every (tb_mortality, rel_sus, beta) combination is simulated with
    ``run_sim``. The flattened results of each run are collected under a
    scenario key and passed to ``cf.plot_results``; afterwards each grid
    plotting function is called with the full simulation grid and a shared
    timestamp.

    Args:
        beta_vals: Beta values to sweep.
        rel_sus_vals: Relative susceptibility values to sweep.
        tb_mortality_vals: TB mortality values to sweep.

    Returns:
        Nested list indexed as ``sim_grid[m][i][j]`` (tb_mortality, rel_sus,
        beta) holding the simulation returned by ``run_sim`` for each combination.
    """
    timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H%M")  # e.g., 2025_06_24_0330

    n_beta = len(beta_vals)
    n_rel_sus = len(rel_sus_vals)
    n_mort = len(tb_mortality_vals)
    total_runs = n_beta * n_rel_sus * n_mort

    sim_grid = [[[None] * n_beta for _ in range(n_rel_sus)] for _ in range(n_mort)]
    results = {}

    for m, tb_mortality in enumerate(tb_mortality_vals):
        for i, rel_sus in enumerate(rel_sus_vals):
            for j, beta in enumerate(beta_vals):
                scen_key = f'beta={beta:.3f}_rel_sus={rel_sus:.2f}_mort={tb_mortality:.1e}'
                print(f"▶️ Running simulation {scen_key} ({m},{i},{j})/{total_runs}")
                sim = run_sim(beta=beta, rel_sus_latentslow=rel_sus, tb_mortality=tb_mortality)
                sim_grid[m][i][j] = sim
                results[scen_key] = sim.results.flatten()

    # Scenario overview via common_functions.plot_results. Units/definitions
    # of the metrics it plots, as noted by the original author:
    #   - 'active':     people with active TB disease (count)
    #   - 'latent':     people with latent TB infection (count)
    #   - 'incidence':  new TB infections per time step (count)
    #   - 'prevalence': TB prevalence as fraction of total population (0-1)
    #   - 'sought':     people who sought care for TB symptoms (count)
    #   - 'eligible':   people eligible for care-seeking (active TB cases, count)
    #   - 'tested':     people who received diagnostic testing (count)
    #   - 'diagnosed':  people diagnosed with TB (count)
    #   - 'treated':    people who started TB treatment (count)
    #   - 'success':    successful TB treatment completions (count)
    #   - 'failure':    failed TB treatment attempts (count)
    cf.plot_results(results, dark=False)

    # Grid plots, one figure per metric family; all share the same arguments
    grid_args = (sim_grid, beta_vals, rel_sus_vals, tb_mortality_vals, timestamp)
    plot_total_population_grid(*grid_args)
    plot_population_pyramid_grid(*grid_args, target_year=2022)
    remaining_grid_plots = (
        plot_hiv_metrics_grid,
        plot_health_seeking_grid,
        plot_diagnostic_grid,
        plot_cumulative_diagnostic_grid,
        plot_treatment_grid,
        plot_cumulative_treatment_grid,
        plot_tb_sweep_with_data,
        plot_annualized_infection_rate_grid,
        plot_age_prevalence_grid,
        plot_hiv_tb_coinfection_grid,
        plot_case_notification_rate_grid,
        plot_tb_mortality_rate_grid,
    )
    for plot_grid in remaining_grid_plots:
        plot_grid(*grid_args)

    # Hand the simulations back for use by other functions
    return sim_grid
if __name__ == '__main__':
    # Script entry point: configure the parameter ranges, run the sweep and
    # report the wall-clock runtime.

    # Sweep configuration, reduced to 2 parameter combinations for faster runtime
    beta_range = np.array([0.025, 0.035])  # Higher infectiousness range 0.025-0.035
    rel_sus_range = np.array([0.15])       # Single value for reinfection susceptibility
    tb_mortality_range = [3e-4]            # Single value for TB mortality

    # Record start time
    start_wallclock = time.time()
    start_datetime = datetime.datetime.now()
    print(f"Sweep started at {start_datetime.strftime('%Y-%m-%d %H:%M:%S')}")

    # Run sweep
    refined_sweep(beta_range, rel_sus_range, tb_mortality_range)

    # Report end time and elapsed minutes
    end_wallclock = time.time()
    end_datetime = datetime.datetime.now()
    elapsed_minutes = (end_wallclock - start_wallclock) / 60
    print(f"Sweep finished at {end_datetime.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Total runtime: {elapsed_minutes:.1f} minutes")

    # A quick check of the health-seeking and diagnostic integration can be run
    # with the call below. Note that the function is defined further down in
    # this file, so the call must be placed after that definition rather than
    # uncommented here (here it would raise NameError).
    # test_health_seeking_diagnostic_integration()
def test_hiv_integration():
    """Run a short, small simulation and print the final HIV and HIV-positive TB prevalence."""
    print("Testing HIV integration...")

    # Short, small run so the check finishes quickly
    sim_settings = dict(beta=0.003, rel_sus_latentslow=0.05, tb_mortality=4e-4, years=50, n_agents=500)
    sim = run_sim(**sim_settings)

    # Dump HIV result diagnostics
    debug_hiv_results(sim)

    # HIV prevalence at the last time step
    hiv_prev = compute_hiv_prevalence(sim)
    print(f"HIV prevalence at final timestep: {hiv_prev[-1]:.3f}")

    # HIV-positive TB prevalence at the last time step
    hiv_tb_prev = compute_hiv_positive_tb_prevalence(sim)
    print(f"HIV-positive TB prevalence at final timestep: {hiv_tb_prev[-1]:.3f}")

    print("HIV integration test completed.")
# Uncomment the line below to run the HIV integration test # test_hiv_integration()
def test_health_seeking_diagnostic_integration():
    """
    Run a short simulation and print the final health-seeking, diagnostic and treatment results.

    For each of the three result groups, the last value of every expected
    metric is printed. A missing result group and a missing individual metric
    are reported separately, so one absent metric key is no longer reported
    as the whole group being unavailable.
    """
    print("Testing health-seeking and diagnostic integration...")

    # Short, small run so the check finishes quickly
    sim = run_sim(beta=0.003, rel_sus_latentslow=0.05, tb_mortality=4e-4, years=50, n_agents=500)

    # (results key, display name, ((printed label, metric key), ...))
    report_spec = (
        ('healthseekingbehavior', 'Health-seeking', (
            ('Final new sought care', 'new_sought_care'),
            ('Final cumulative sought care', 'n_sought_care'),
            ('Final eligible', 'n_eligible'),
        )),
        ('tbdiagnostic', 'Diagnostic', (
            ('Final tested', 'n_tested'),
            ('Final test positive', 'n_test_positive'),
            ('Final test negative', 'n_test_negative'),
            ('Cumulative test positive', 'cum_test_positive'),
            ('Cumulative test negative', 'cum_test_negative'),
        )),
        ('tbtreatment', 'Treatment', (
            ('Final treated', 'n_treated'),
            ('Final treatment success', 'n_treatment_success'),
            ('Final treatment failure', 'n_treatment_failure'),
            ('Cumulative treated', 'cum_treated'),
            ('Cumulative treatment success', 'cum_treatment_success'),
            ('Cumulative treatment failure', 'cum_treatment_failure'),
        )),
    )

    for module_key, title, metrics in report_spec:
        try:
            module_results = sim.results[module_key]
        except KeyError:
            print(f"✗ {title} results not found")
            continue

        print(f"✓ {title} results available")
        for label, metric_key in metrics:
            try:
                final_value = module_results[metric_key].values[-1]
            except KeyError:
                # Only this metric is missing; keep reporting the others
                print(f" - {label}: ✗ metric '{metric_key}' not found")
                continue
            print(f" - {label}: {final_value}")

    print("Health-seeking and diagnostic integration test completed.")
# Uncomment the line below to run the health-seeking and diagnostic integration test # test_health_seeking_diagnostic_integration()