#!/usr/bin/env python3
"""
Household Network Plotting Script
This script provides comprehensive visualization tools for household networks in TBSim.
It includes basic network plots, advanced statistics, and simulation examples.
Usage:
python scripts/plot_household_networks.py [--example basic|advanced|community|simulation|all] [--save] [--output-dir DIR]
"""
import argparse
import itertools
import os
import sys
from collections import Counter

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

# Add the parent directory to the path to import tbsim
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import tbsim as mtb
import starsim as ss
from tbsim.networks import HouseholdNet, plot_household_structure
[docs]
def plot_household_network_basic(households, title="Household Network", save_path=None):
    """
    Draw a household network with NetworkX on a dark background.

    Each household becomes a complete subgraph (every pair of members is
    linked). Nodes and edges are coloured per household from the viridis
    colormap, positions come from a seeded spring layout, and the figure
    gets a per-household legend plus a box with summary statistics.

    Side effects: switches the global matplotlib style to
    ``'dark_background'``, optionally saves the figure, and calls
    ``plt.show()``.

    Args:
        households: List of lists, where each inner list contains agent UIDs in a household
        title: Title for the plot
        save_path: Optional path to save the plot (written at 300 dpi)

    Returns:
        networkx.Graph: The graph that was drawn; each edge carries the
        ``household`` index and its ``color``.
    """
    # Set dark theme (global style change, kept for backward compatibility)
    plt.style.use('dark_background')

    household_colors = plt.cm.viridis(np.linspace(0, 1, len(households)))
    # Each household's member pairs, computed once and reused both to build the
    # graph and to draw it, instead of rescanning every edge of G per household.
    household_edges = [list(itertools.combinations(hh, 2)) for hh in households]

    G = nx.Graph()
    G.add_nodes_from(agent for hh in households for agent in hh)
    for hh_idx, edges in enumerate(household_edges):
        G.add_edges_from(edges, household=hh_idx, color=household_colors[hh_idx])

    fig, ax = plt.subplots(figsize=(14, 10))
    ax.set_facecolor('#1a1a1a')
    fig.patch.set_facecolor('#1a1a1a')

    # Seeded so repeated calls produce the same layout.
    pos = nx.spring_layout(G, seed=42, k=3, iterations=100)

    # Draw edges first so nodes are painted on top of them.
    for hh_idx, edges in enumerate(household_edges):
        if not edges:  # single-member households have no edges
            continue
        nx.draw_networkx_edges(G, pos,
                               edgelist=edges,
                               edge_color=household_colors[hh_idx],
                               width=3,
                               alpha=0.6,
                               style='solid',
                               ax=ax)

    for hh_idx, household in enumerate(households):
        if not household:
            continue
        n_members = len(household)
        nx.draw_networkx_nodes(G, pos,
                               nodelist=household,
                               node_color=[household_colors[hh_idx]] * n_members,
                               # Larger households get larger nodes.
                               node_size=[400 + 100 * n_members] * n_members,
                               alpha=0.9,
                               edgecolors='white',
                               linewidths=1,
                               ax=ax)

    nx.draw_networkx_labels(G, pos,
                            font_size=14,
                            font_weight='bold',
                            font_color='white',
                            ax=ax)

    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w',
                   markerfacecolor=household_colors[hh_idx],
                   markersize=15, alpha=0.9, markeredgecolor='white',
                   markeredgewidth=1,
                   label=f'Household {hh_idx + 1} (n={len(household)})')
        for hh_idx, household in enumerate(households)
    ]
    ax.legend(handles=legend_elements, loc='upper right',
              bbox_to_anchor=(1.15, 1),
              frameon=True,
              fancybox=True,
              shadow=True,
              facecolor='#2a2a2a',
              edgecolor='#404040',
              fontsize=10)
    ax.set_title(title, fontsize=20, fontweight='bold', color='white', pad=20)
    # Hide the axes entirely (a grid would be invisible once the axis is off).
    ax.axis('off')

    household_sizes = [len(hh) for hh in households]
    # np.mean([]) warns and yields nan; report 0 for an empty input instead.
    mean_size = np.mean(household_sizes) if household_sizes else 0.0
    stats_text = (f"Total Agents: {sum(household_sizes)}\n"
                  f"Total Households: {len(households)}\n"
                  f"Mean Size: {mean_size:.1f}")
    ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
            verticalalignment='top',
            bbox=dict(boxstyle='round,pad=0.5', facecolor='#404040', alpha=0.8, edgecolor='#606060'),
            fontsize=10, color='white')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='#1a1a1a')
        print(f"Plot saved to: {save_path}")
    plt.show()
    return G
[docs]
def _advanced_household_layout(households):
    """
    Compute node positions for the advanced household view.

    Households are spaced evenly by angle around the origin, each at radius
    ``4 + 0.5 * size`` so larger households get more room. A one-person
    household sits on that point, a two-person household is placed side by
    side (x offsets of -0.3 and +0.3), and larger households are spread on a
    small circle of radius ``0.8 + 0.2 * size`` around it.

    Returns:
        dict: agent UID -> (x, y) position.
    """
    pos = {}
    household_angles = np.linspace(0, 2 * np.pi, len(households), endpoint=False)
    for hh_angle, household in zip(household_angles, households):
        size = len(household)
        radius = 4 + 0.5 * size  # larger households get more space
        hh_x = radius * np.cos(hh_angle)
        hh_y = radius * np.sin(hh_angle)
        if size == 1:
            pos[household[0]] = (hh_x, hh_y)
        elif size == 2:
            pos[household[0]] = (hh_x - 0.3, hh_y)
            pos[household[1]] = (hh_x + 0.3, hh_y)
        else:
            inner_radius = 0.8 + 0.2 * size
            agent_angles = np.linspace(0, 2 * np.pi, size, endpoint=False)
            for agent, agent_angle in zip(household, agent_angles):
                pos[agent] = (hh_x + inner_radius * np.cos(agent_angle),
                              hh_y + inner_radius * np.sin(agent_angle))
    return pos


def _draw_household_size_distribution(ax, household_sizes):
    """
    Draw the household-size bar chart and a summary-statistics box on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        household_sizes: List of ints, one size per household.
    """
    # Sort by size so the viridis gradient runs from smallest to largest size
    # (iterating a set gave no guaranteed order).
    size_counts = sorted(Counter(household_sizes).items())
    if size_counts:
        sizes = [size for size, _ in size_counts]
        counts = [count for _, count in size_counts]
        bar_colors = plt.cm.viridis(np.linspace(0, 1, len(size_counts)))
        bars = ax.bar(sizes, counts,
                      color=bar_colors, alpha=0.8, edgecolor='white', linewidth=1)
        # Value labels just above each bar.
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., height + 0.05,
                    f'{int(height)}', ha='center', va='bottom', fontweight='bold',
                    color='white', fontsize=12)

    ax.set_xlabel('Household Size', fontsize=12, color='white')
    ax.set_ylabel('Number of Households', fontsize=12, color='white')
    ax.set_title('Household Size Distribution', fontsize=18, fontweight='bold', color='white', pad=20)
    ax.grid(True, alpha=0.2, color='white', linestyle='-', linewidth=0.5)
    ax.tick_params(colors='white', labelsize=12)

    # Guard the empty case: max()/min() would raise and np.mean/np.std would
    # warn and return nan.
    if household_sizes:
        mean_size = np.mean(household_sizes)
        std_size = np.std(household_sizes)
        min_size, max_size = min(household_sizes), max(household_sizes)
    else:
        mean_size = std_size = 0.0
        min_size = max_size = 0
    stats_text = (f"Total Agents: {sum(household_sizes)}\n"
                  f"Total Households: {len(household_sizes)}\n"
                  f"Mean Size: {mean_size:.1f}\n"
                  f"Std Dev: {std_size:.1f}\n"
                  f"Size Range: {min_size}-{max_size}")
    ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
            verticalalignment='top',
            bbox=dict(boxstyle='round,pad=0.5', facecolor='#404040', alpha=0.9,
                      edgecolor='#606060', linewidth=1),
            fontsize=10, color='white', fontweight='bold')


def plot_household_network_advanced(households, title="Advanced Household Network", save_path=None):
    """
    Create an advanced two-panel visualization of household networks.

    Left panel: agents clustered per household on a ring layout, with
    within-household edges, ``HH1``, ``HH2``, ... labels and faint dashed
    lines between household centres. Right panel: bar chart of household
    sizes with a summary-statistics box. ``title`` is shown as the figure
    title above both panels.

    Side effects: switches the global matplotlib style to
    ``'dark_background'``, optionally saves the figure, and calls
    ``plt.show()``.

    Args:
        households: List of lists, where each inner list contains agent UIDs in a household
        title: Title for the plot
        save_path: Optional path to save the plot (written at 300 dpi)

    Returns:
        networkx.Graph: Graph whose nodes carry ``household`` and
        ``household_size`` attributes and whose edges carry ``household``.
    """
    # Set dark theme (global style change, kept for backward compatibility)
    plt.style.use('dark_background')

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
    fig.patch.set_facecolor('#1a1a1a')
    ax1.set_facecolor('#1a1a1a')
    ax2.set_facecolor('#1a1a1a')
    # Previously ``title`` was accepted but never drawn.
    fig.suptitle(title, fontsize=20, fontweight='bold', color='white')

    # Build the graph: one complete subgraph per household, with membership
    # stored on nodes and edges.
    household_edges = [list(itertools.combinations(hh, 2)) for hh in households]
    G = nx.Graph()
    G.add_nodes_from(agent for hh in households for agent in hh)
    for hh_idx, household in enumerate(households):
        for agent in household:
            G.nodes[agent]['household'] = hh_idx
            G.nodes[agent]['household_size'] = len(household)
    for hh_idx, edges in enumerate(household_edges):
        G.add_edges_from(edges, household=hh_idx)

    household_colors = plt.cm.plasma(np.linspace(0, 1, len(households)))
    pos = _advanced_household_layout(households)
    # Household centres, computed once and reused for labels and connectors.
    centers = [np.mean([pos[agent] for agent in hh], axis=0) if hh else None
               for hh in households]

    # Faint connectors between household centres; zorder=0 keeps them behind
    # the edges and nodes instead of drawn over them.
    present_centers = [c for c in centers if c is not None]
    for (x1, y1), (x2, y2) in itertools.combinations(present_centers, 2):
        ax1.plot([x1, x2], [y1, y2], '--', color='white', alpha=0.1, linewidth=1, zorder=0)

    # Edges before nodes so nodes are painted on top.
    for hh_idx, edges in enumerate(household_edges):
        if not edges:
            continue
        nx.draw_networkx_edges(G, pos,
                               edgelist=edges,
                               edge_color=household_colors[hh_idx],
                               width=4,
                               alpha=0.7,
                               style='solid',
                               ax=ax1)

    for hh_idx, household in enumerate(households):
        if not household:
            continue
        # Node area scales with household size.
        node_size = 600 * (1 + 0.3 * len(household))
        nx.draw_networkx_nodes(G, pos,
                               nodelist=household,
                               node_color=[household_colors[hh_idx]] * len(household),
                               node_size=[node_size] * len(household),
                               alpha=0.9,
                               edgecolors='white',
                               linewidths=1,
                               ax=ax1)

    nx.draw_networkx_labels(G, pos, font_size=12, font_weight='bold',
                            font_color='white', ax=ax1)

    # Household labels just above each household's centre.
    for hh_idx, center in enumerate(centers):
        if center is None:
            continue
        ax1.text(center[0], center[1] + 1.5, f'HH{hh_idx+1}',
                 fontsize=12, fontweight='bold', color=household_colors[hh_idx],
                 ha='center', va='center',
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='#2a2a2a', alpha=0.8))

    ax1.set_title('Enhanced Household Network Structure', fontsize=18, fontweight='bold', color='white', pad=20)
    ax1.axis('off')

    _draw_household_size_distribution(ax2, [len(hh) for hh in households])

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='#1a1a1a')
        print(f"Plot saved to: {save_path}")
    plt.show()
    return G
[docs]
def demonstrate_household_network_simulation():
    """
    Demonstrate how to extract and plot household networks from a TBSim simulation.

    Builds a ``HouseholdNet`` from five hard-coded households, runs an
    ``ss.Sim`` with ``mtb.TB()`` from 2000 to 2001 in one-month steps, prints
    a short summary, then draws the basic and advanced household plots.

    Returns:
        tuple: ``(sim, household_net)`` -- the simulation after ``run()`` and
        the network it used.
    """
    print("Creating household network simulation...")
    households = [
        [0, 1, 2],  # Family of 3
        [3, 4],  # Couple
        [5, 6, 7, 8],  # Family of 4
        [9],  # Single person
        [10, 11, 12]  # Family of 3
    ]
    total_agents = sum(len(hh) for hh in households)
    # Size the population from the household UIDs rather than a separately
    # hard-coded count, so editing the households cannot leave UIDs outside
    # the population. Assumes UIDs are 0-based agent indices -- TODO confirm
    # against HouseholdNet.
    n_agents = max(uid for hh in households for uid in hh) + 1

    household_net = HouseholdNet(hhs=households)
    sim = ss.Sim(
        people=ss.People(n_agents=n_agents),
        networks=[household_net],
        diseases=mtb.TB(),
        pars=dict(start=ss.date(2000), stop=ss.date(2001), dt=ss.months(1))
    )
    sim.run()

    print(f"Simulation completed with {len(households)} households")
    print(f"Total agents: {total_agents}")
    print(f"Network edges: {len(household_net.edges.p1) if hasattr(household_net.edges, 'p1') else 'None'}")

    plot_household_network_basic(households, "TBSim Household Network - Basic View")
    plot_household_network_advanced(households, "TBSim Household Network - Advanced View")
    return sim, household_net
[docs]
def create_sample_households():
    """
    Return named example household layouts used by the plotting demos.

    Each value is a list of households, and each household is a list of agent
    UIDs. Within every sample the UIDs are consecutive integers starting at 0,
    handed out household by household in the order listed. New lists are
    built on every call, so callers may mutate the result freely.
    """
    def consecutive(sizes):
        # Turn a list of household sizes into lists of consecutive UIDs.
        groups = []
        next_uid = 0
        for size in sizes:
            groups.append(list(range(next_uid, next_uid + size)))
            next_uid += size
        return groups

    simple_households = consecutive([3, 2, 4])
    complex_households = consecutive([
        4,  # Large family
        2,  # Couple
        1,  # Single
        3,  # Family of 3
        2,  # Another couple
        5,  # Very large family
    ])
    realistic_households = consecutive([
        5,  # Extended family
        2,  # Young couple
        1,  # Elderly single
        3,  # Nuclear family
        4,  # Large family
        2,  # Couple
        3,  # Family with children
        1,  # Single person
    ])
    extended_households = consecutive([
        6,  # Extended family with grandparents
        2,  # Young couple
        1,  # Elderly single
        3,  # Nuclear family
        4,  # Large family
        2,  # Couple
        3,  # Family with children
        1,  # Single person
        7,  # Very large extended family
        3,  # Another nuclear family
        4,  # Family of 4
        2,  # Another couple
        5,  # Large family
    ])
    community_households = consecutive([
        7,  # Large extended family
        2,  # Couple
        1,  # Single
        3,  # Nuclear family
        4,  # Large family
        2,  # Couple
        3,  # Family with children
        1,  # Single person
        8,  # Very large family
        3,  # Nuclear family
        4,  # Family of 4
        2,  # Couple
        5,  # Large family
        3,  # Another nuclear family
        6,  # Extended family
        2,  # Young couple
        4,  # Family with teenagers
        1,  # Elderly single
    ])

    return {
        'simple': simple_households,
        'complex': complex_households,
        'realistic': realistic_households,
        '': extended_households,
        'community': community_households,
    }
[docs]
def main(argv=None):
    """
    Parse command-line options and run the selected household-network examples.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, so ``main()`` behaves as before.
    """
    parser = argparse.ArgumentParser(description='Plot household networks from TBSim')
    parser.add_argument('--example', choices=['basic', 'advanced', 'simulation', '', 'community', 'all'],
                        default='all', help='Type of example to run')
    parser.add_argument('--save', action='store_true',
                        help='Save plots to files')
    parser.add_argument('--output-dir', default='scripts/results',
                        help='Output directory for saved plots')
    args = parser.parse_args(argv)

    # Create output directory if it doesn't exist
    if args.save:
        os.makedirs(args.output_dir, exist_ok=True)

    def output_path(filename):
        # Where to save *filename*, or None when --save was not given.
        return os.path.join(args.output_dir, filename) if args.save else None

    print("Household Network Plotting Script")
    print("=" * 50)
    print("🎨 Enhanced Dark Theme Visualizations")
    print("=" * 50)

    samples = create_sample_households()

    # One row per sample-based example, run in this order for --example all:
    # (--example value, banner, sample key, output file, plot function, title)
    sample_examples = [
        ('basic', "\n🎯 Example 1: Basic Network Visualization",
         'simple', 'household_network_basic_dark.png',
         plot_household_network_basic, " Household Network"),
        ('advanced', "\n🚀 Example 2: Advanced Network Visualization with Enhanced Statistics",
         'complex', 'household_network_advanced_dark.png',
         plot_household_network_advanced, "Advanced Household Network Analysis"),
        ('', "\n🌟 Example 3: Community Network",
         '', 'household_network__dark.png',
         plot_household_network_advanced, " Community Network"),
        ('community', "\n🏘️ Example 4: Large Community Network Analysis",
         'community', 'household_network_community_dark.png',
         plot_household_network_advanced, "Large Community Network Analysis"),
    ]
    for example_key, banner, sample_key, filename, plot_fn, plot_title in sample_examples:
        if args.example in (example_key, 'all'):
            print(banner)
            plot_fn(samples[sample_key], plot_title, output_path(filename))

    if args.example in ('simulation', 'all'):
        print("\n⚡ Example 5: Full TBSim Simulation with Household Networks")
        demonstrate_household_network_simulation()

        # Exercise the library's own plotting function under the dark theme.
        print("\n🔧 Example 6: Built-in TBSim Plotting Function (Dark Theme)")
        save_path = output_path('household_network_builtin_dark.png')
        plt.style.use('dark_background')
        plot_household_structure(samples['realistic'], 'TBSim Built-in Household Plot (Dark Theme)')
        # NOTE(review): this saves whatever figure is current after the call; if
        # plot_household_structure shows/closes its own figure the file may be
        # blank -- confirm against tbsim.networks.
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='#1a1a1a')
            print(f"Built-in plot saved to: {save_path}")

    print("\n" + "=" * 50)
    print(" Demonstration Complete!")
    print("=" * 50)
    print(" Available Enhanced Plotting Functions:")
    print(" plot_household_network_basic(): network visualization with dark theme")
    print(" plot_household_network_advanced(): Advanced visualization with enhanced statistics")
    print(" plot_household_structure(): Built-in TBSim function (dark theme applied)")
    print(" demonstrate_household_network_simulation(): Full simulation example")
    print("\n Features:")
    print(" • Dark theme with color palettes")
    print(" • Enhanced node and edge styling")
    print(" • Dynamic sizing based on household characteristics")
    print(" • Advanced statistics and visual elements")
    print(" • High-resolution output with proper dark backgrounds")
# Script entry point: run the CLI only when executed directly, not when imported.
if __name__ == "__main__":
    main()