#!/usr/bin/env python3
# Copyright 2025 Kris Sigurdson
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Check for --help flag at the very beginning
import sys
showing_help = '--help' in sys.argv or '-h' in sys.argv
# Import other modules
import numpy as np
import matplotlib
matplotlib.use('Agg', force=True)
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from scipy.stats import energy_distance
from scipy.interpolate import UnivariateSpline
import math
import os
import glob
import re
import gc
import psutil
import struct
import imageio.v2 as imageio # Use imageio v2 for get_writer
from io import BytesIO
import multiprocessing as mp
import matplotlib.colors as colors
from matplotlib.colors import LogNorm
from scipy.stats import binned_statistic
import argparse
import signal
import contextlib
import time
import atexit
from tqdm import tqdm
import logging
import shutil
import traceback
import threading
def find_project_root():
"""
Determines the project root directory.
Returns
-------
str
Path to the project root directory.
Notes
-----
The project root is identified by the presence of a 'src' directory.
Uses path normalization for cross-platform compatibility.
"""
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
if os.path.basename(script_dir) == 'python':
potential_root = os.path.dirname(os.path.dirname(script_dir))
else:
potential_root = script_dir
potential_root = os.path.normpath(potential_root)
if os.path.isdir(os.path.join(potential_root, 'src')):
return potential_root
parent_dir = os.path.dirname(potential_root)
if os.path.isdir(os.path.join(parent_dir, 'src')):
return parent_dir
grandparent_dir = os.path.dirname(parent_dir)
if os.path.isdir(os.path.join(grandparent_dir, 'src')):
return grandparent_dir
print(f"Warning: Could not verify project root via 'src' directory. Using derived directory: {potential_root}", file=sys.stderr)
return potential_root
# Set up project paths
PROJECT_ROOT = find_project_root()
DATA_DIR = os.path.join(PROJECT_ROOT, 'data')
# Add the project src directory to the module search path if needed
src_dir = os.path.join(PROJECT_ROOT, 'src')
if src_dir not in sys.path:
sys.path.insert(0, src_dir)
# Define global variables
suffix = "" # Will be set in main()
start_snap = 0
end_snap = 0
step_snap = 1
duration = 100.0 # Default frame duration in ms (for fps=10)
section_delay = 0.0 # No delay by default (will be set to 5.0 if paced_mode is enabled)
progress_delay = 0.0 # No delay by default (will be set to 2.0 if paced_mode is enabled)
enable_logging = False # Default: don't log info messages unless --log is specified
paced_mode = False # When true, enables delays between sections with timers
def show_section_delay(delay_seconds):
"""
Display a visual timer with dots for section transitions.
Parameters
----------
delay_seconds : float
The number of seconds to delay/pace
Notes
-----
Shows an upward counting timer with dots accumulating each second.
Displays the full delay_seconds at the end to ensure consistent timing.
Allows the user to skip the delay with spacebar or pause/resume with 'p'.
"""
try:
import msvcrt # Windows
def kbhit():
return msvcrt.kbhit()
def getch():
return msvcrt.getch().decode('utf-8').lower()
is_windows = True
except ImportError:
try:
import termios, fcntl, os, select # Unix/Linux/MacOS
is_windows = False
# Save the terminal settings
fd = sys.stdin.fileno()
old_settings = termios.tcgetattr(fd)
# Setup non-blocking input
def setup_nonblocking():
new = termios.tcgetattr(fd)
new[3] = new[3] & ~termios.ICANON
new[3] = new[3] & ~termios.ECHO
termios.tcsetattr(fd, termios.TCSANOW, new)
oldflags = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, oldflags | os.O_NONBLOCK)
def restore_terminal():
termios.tcsetattr(fd, termios.TCSAFLUSH, old_settings)
def kbhit():
dr, dw, de = select.select([sys.stdin], [], [], 0)
return len(dr) > 0
def getch():
if kbhit():
ch = sys.stdin.read(1)
return ch.lower()
return None
setup_nonblocking()
        except Exception:
            # If terminal input cannot be configured (no termios module, stdin
            # not a TTY, etc.), fall back to basic non-interactive behavior
def kbhit():
return False
def getch():
return None
def restore_terminal():
pass
is_windows = False
start_time = time.time()
dots = ""
paused = False
pause_start_time = 0
pause_elapsed_time = 0 # Store the elapsed time when entering pause
total_paused_time = 0
# Helper to format display string
def get_display_text(elapsed_time, dots_str, pause_status):
# Fix the position of the message by padding the dots to a consistent length
max_dots = int(delay_seconds) + 1 # Maximum possible dots
dots_padding = " " * (max_dots - len(dots_str))
# Format base display string - time goes right after "for"
base_display = f"\rPacing output for {elapsed_time:.2f}s "
# Calculate message length and ensure padding is sufficient
pause_msg = "Press Spacebar to Skip, P to Resume"
run_msg = "Press Spacebar to Skip, P to Pause"
# Ensure extra padding is present for the longer message
max_msg_length = max(len(pause_msg), len(run_msg)) + 4 # +4 for brackets and buffer
# Define spacing constants
status_text = "[PAUSED] "
status_padding = " " * len(status_text)
# Add appropriate status or padding after the time
if pause_status:
# When paused, show [PAUSED] after the time
display_with_status = base_display + status_text
msg = pause_msg
else:
# When not paused, add padding to keep elements position consistent
display_with_status = base_display + status_padding
msg = run_msg # No extra space here - it will be added after the bracket
# Place message 14 spaces to the right and ensure it stays in place
guidance = " [" + msg + "] " # Added a space after the closing bracket for padding
# Complete the display string with dots and guidance
return f"{display_with_status}{dots_str}{dots_padding}{guidance}"
# Initial display with empty dots and consistent padding
initial_dots = "." # Start with one dot
sys.stdout.write("\n")
# Always start in the unpaused state for the initial display
sys.stdout.write(get_display_text(0.0, initial_dots, False))
sys.stdout.flush()
try:
# Keep running until the delay duration is reached or exceeded
while True:
if kbhit():
key = getch()
if key == ' ': # Spacebar always skips entirely, even when paused
# Clear the entire line and return immediately
sys.stdout.write("\r\033[2K")
sys.stdout.flush()
return
elif key == 'p': # 'p' toggles pause/resume
if paused:
# Resuming - add elapsed pause time to total
total_paused_time += time.time() - pause_start_time
paused = False
else:
# Pausing - record the time pause began and save current elapsed time
pause_start_time = time.time()
# Calculate and store the current elapsed time to display during pause
pause_elapsed_time = time.time() - start_time - total_paused_time
paused = True
if not paused:
# Only update time when not paused
current_time = time.time()
elapsed = current_time - start_time - total_paused_time
if elapsed >= delay_seconds:
# The end of the delay has been reached
# Instead of showing a final display, clear the line completely (like spacebar skip)
sys.stdout.write("\r\033[2K")
sys.stdout.flush()
break
# Add a dot every second
new_dots = "." * (int(elapsed) + 1)
if new_dots != dots:
dots = new_dots
# Update display
display_content = get_display_text(elapsed, dots, paused)
sys.stdout.write("\r\033[2K" + truncate_and_pad_string(display_content))
sys.stdout.flush()
else:
# When paused, display the stored time and dots from when we entered pause
# Use the stored pause_elapsed_time which is frozen at the moment of pause
display_content_paused = get_display_text(pause_elapsed_time, dots, paused)
sys.stdout.write("\r\033[2K" + truncate_and_pad_string(display_content_paused))
sys.stdout.flush()
time.sleep(0.1)
# No newline after completing - line is already cleared
finally:
# Ensure terminal is restored if using Unix-style input
if not is_windows and 'restore_terminal' in locals():
restore_terminal()
def show_progress_delay(delay_seconds):
"""
Display a minimal visual indicator during progress delay.
Parameters
----------
delay_seconds : float
The number of seconds to delay/pace
Notes
-----
Shows only dots accumulating with each second, without text.
"""
start_time = time.time()
dots = ""
while (time.time() - start_time) < delay_seconds:
elapsed = time.time() - start_time
# Add a dot every second
new_dots = "." * (int(elapsed) + 1)
if new_dots != dots:
dots = new_dots
# Truncate the dots string (unlikely to be needed but consistent)
sys.stdout.write(f"\r\033[2K{truncate_and_pad_string(dots)}")
sys.stdout.flush()
time.sleep(0.1)
sys.stdout.write("\r\033[2K") # Clear the dots
sys.stdout.flush()
# Global variables for animation data
mass_snapshots = []
density_snapshots = []
psi_snapshots = []
# Global variables for histogram data tracking
particles_original_count = 0
particles_final_original_count = 0
# Configure tqdm to work properly in all environments
tqdm.monitor_interval = 0 # Disable monitor thread to avoid issues
# --- TQDM Dynamic Formatting Constants ---
# Format Strings (No Bar, No Percentage for counter_tqdm line)
FMT_FULL = '{desc}: {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'
FMT_NO_REM = '{desc}: {n_fmt}/{total_fmt} [{elapsed}, {rate_fmt}]'
FMT_NO_RATE = '{desc}: {n_fmt}/{total_fmt} [{elapsed}]'
FMT_COUNTS = '{desc}: {n_fmt}/{total_fmt}'
FMT_DESC = '{desc}'
# Descriptions (Full and Short)
DESCRIPTIONS = {
'proc_anim_frames': {
'full': "Processing animation frames", 'short': "Proc. anim frames"},
'encod_frames': {
'full': "Encoding frames", 'short': "Encod. frames"},
'preproc_phase': {
'full': "Preprocessing phase space data", 'short': "Preproc. phase data"},
'render_phase': {
'full': "Rendering phase space frames", 'short': "Render phase frames"},
'proc_sorted_snaps': {
'full': "Processing rank sorted snapshot files", 'short': "Proc. sorted snaps"},
'proc_unsorted_snaps': {
'full': "Processing unsorted snapshot files", 'short': "Proc. unsorted snaps"},
'proc_energy_series': {
'full': "Processing energy data for time series", 'short': "Proc. energy series"},
'gen_mass_frames': {
'full': "Generating mass profile frames", 'short': "Gen. mass frames"},
'gen_dens_frames': {
'full': "Generating density profile frames", 'short': "Gen. density frames"},
'gen_psi_frames': {
'full': "Generating psi profile frames", 'short': "Gen. psi frames"},
}
# Pre-calculated Thresholds (Min width needed for format with specific desc length)
# Based on: Counts=11, Elapsed=9, Remaining=9, Rate=14, Separators
TQDM_THRESHOLDS = {
'proc_anim_frames': { # Len 28/18
'full': {'full': 80, 'no_rem': 70, 'no_rate': 52, 'counts': 41, 'desc': 28},
'short': {'full': 70, 'no_rem': 60, 'no_rate': 42, 'counts': 31, 'desc': 18}},
'encod_frames': { # Len 15/13
'full': {'full': 67, 'no_rem': 57, 'no_rate': 39, 'counts': 28, 'desc': 15},
'short': {'full': 65, 'no_rem': 55, 'no_rate': 37, 'counts': 26, 'desc': 13}},
'preproc_phase': { # Len 30/20
'full': {'full': 82, 'no_rem': 72, 'no_rate': 54, 'counts': 43, 'desc': 30},
'short': {'full': 72, 'no_rem': 62, 'no_rate': 44, 'counts': 33, 'desc': 20}},
'render_phase': { # Len 28/19
'full': {'full': 80, 'no_rem': 70, 'no_rate': 52, 'counts': 41, 'desc': 28},
'short': {'full': 71, 'no_rem': 61, 'no_rate': 43, 'counts': 32, 'desc': 19}},
'proc_sorted_snaps': { # Len 36/18
'full': {'full': 88, 'no_rem': 78, 'no_rate': 60, 'counts': 49, 'desc': 36},
'short': {'full': 70, 'no_rem': 60, 'no_rate': 42, 'counts': 31, 'desc': 18}},
'proc_unsorted_snaps': { # Len 34/20
'full': {'full': 86, 'no_rem': 76, 'no_rate': 58, 'counts': 47, 'desc': 34},
'short': {'full': 72, 'no_rem': 62, 'no_rate': 44, 'counts': 33, 'desc': 20}},
'proc_energy_series': { # Len 38/20
'full': {'full': 90, 'no_rem': 80, 'no_rate': 62, 'counts': 51, 'desc': 38},
'short': {'full': 72, 'no_rem': 62, 'no_rate': 44, 'counts': 33, 'desc': 20}},
'gen_mass_frames': { # Len 30/16
'full': {'full': 82, 'no_rem': 72, 'no_rate': 54, 'counts': 43, 'desc': 30},
'short': {'full': 68, 'no_rem': 58, 'no_rate': 40, 'counts': 29, 'desc': 16}},
'gen_dens_frames': { # Len 32/19
'full': {'full': 84, 'no_rem': 74, 'no_rate': 56, 'counts': 45, 'desc': 32},
'short': {'full': 71, 'no_rem': 61, 'no_rate': 43, 'counts': 32, 'desc': 19}},
'gen_psi_frames': { # Len 29/15
'full': {'full': 81, 'no_rem': 71, 'no_rate': 53, 'counts': 42, 'desc': 29},
'short': {'full': 67, 'no_rem': 57, 'no_rate': 39, 'counts': 28, 'desc': 15}},
}
# Helper function to select format and description
# --- End TQDM Dynamic Formatting Constants ---
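# Illustrative sketch only (not the module's actual helper): one way the
# constants above could drive format selection, preferring the full
# description and degrading the tqdm format as the terminal narrows.
# The name _select_tqdm_format is hypothetical.
#
#     def _select_tqdm_format(key, width):
#         for desc_len in ('full', 'short'):
#             limits = TQDM_THRESHOLDS[key][desc_len]
#             for fmt, fmt_key in ((FMT_FULL, 'full'), (FMT_NO_REM, 'no_rem'),
#                                  (FMT_NO_RATE, 'no_rate'), (FMT_COUNTS, 'counts'),
#                                  (FMT_DESC, 'desc')):
#                 if width >= limits[fmt_key]:
#                     return DESCRIPTIONS[key][desc_len], fmt
#         return DESCRIPTIONS[key]['short'], FMT_DESC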
# Configure tqdm to match the custom progress bar format
tqdm_kwargs = {
'ascii': False, # Use Unicode characters for progress bars
'position': 0, # Always at position 0
'leave': True, # Leave the progress bar after completion
'miniters': 5, # Update every 5 iterations to reduce flickering
'dynamic_ncols': True, # Enable dynamic width adaptation
'ncols': None, # No fixed width override
'smoothing': 0.3, # Smoother progress updates
# 'bar_format': MUST BE ABSENT or None
}
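# Illustrative use of these defaults (hypothetical `frames` iterable):
#   for frame in tqdm(frames, desc=DESCRIPTIONS['encod_frames']['full'], **tqdm_kwargs):
#       ...  # process the frame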
# Handle keyboard interrupts gracefully
original_sigint_handler = signal.getsignal(signal.SIGINT)
def signal_handler(sig, frame):
"""
Handle keyboard interrupts gracefully.
Parameters
----------
sig : int
Signal number
frame : frame
Current stack frame
Notes
-----
Restores the original signal handler to allow a forced exit
with a second Ctrl+C if needed.
"""
print("\nInterrupted by user. Cleaning up...")
# Restore original handler to allow a second Ctrl+C to force exit
signal.signal(signal.SIGINT, original_sigint_handler)
sys.exit(1)
signal.signal(signal.SIGINT, signal_handler)
# Context manager to suppress stdout during progress bar updates
@contextlib.contextmanager
def suppress_stdout():
"""
Context manager that suppresses stdout output.
Temporarily redirects stdout to a dummy file object that discards output.
Used to prevent unwanted console output during progress bar updates.
Yields
------
None
Control returns to the caller with stdout suppressed
Notes
-----
The original stdout is always restored when leaving the context,
even if an exception occurs.
Examples
--------
>>> with suppress_stdout():
... print("This won't be displayed")
"""
original_stdout = sys.stdout
class DummyFile:
def write(self, x): pass
def flush(self): pass
sys.stdout = DummyFile()
try:
yield
finally:
sys.stdout = original_stdout
# Helper function for processing animation frames (must be at module level for multiprocessing)
def process_animation_frame(frame):
"""
Process a single animation frame for GIF output.
Parameters
----------
frame : numpy.ndarray
The frame image data
Returns
-------
numpy.ndarray
Processed frame ready for GIF encoding
"""
# Create BytesIO buffer for this frame
buf = BytesIO()
# Write frame to BytesIO buffer with explicit format (needed in v3)
imageio.imwrite(buf, frame, extension='.png', plugin='pillow')
# Reset buffer position to start
buf.seek(0)
# Read the frame back as an image
frame_img = imageio.imread(buf, index=0)
# Clean up buffer
buf.close()
return frame_img
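# Illustrative use with a worker pool (hypothetical `frames` list of ndarray
# images); the function lives at module level so multiprocessing can pickle it:
#   with mp.Pool() as pool:
#       gif_frames = pool.map(process_animation_frame, frames)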
# ANSI escape codes for cursor manipulation
HIDE_CURSOR = "\033[?25l" # Hide the cursor
SHOW_CURSOR = "\033[?25h" # Show the cursor
def truncate_and_pad_string(text, fallback_width=100):
"""
Truncates a string to the terminal width and pads with spaces.
Parameters
----------
text : str
The string to potentially truncate and pad.
fallback_width : int, optional
The width to use if terminal size can't be determined, by default 100.
Returns
-------
str
The string, truncated and padded with spaces to the terminal width.
"""
try:
# Attempt to get the terminal size
columns = shutil.get_terminal_size().columns
width = columns
except OSError:
# If output is redirected or terminal size unavailable, use fallback
width = fallback_width
# Truncate the string first
truncated = text[:width]
# Pad the truncated string with spaces to the full width
return truncated.ljust(width)
def get_separator_line(char='-', fallback_width=100):
"""
Generates a separator line spanning the terminal width.
Parameters
----------
char : str, optional
The character to repeat for the line, by default '-'.
fallback_width : int, optional
The width to use if terminal size can't be determined, by default 100.
Returns
-------
str
A string consisting of 'char' repeated to fill the terminal width.
"""
try:
# Attempt to get the terminal size
columns = shutil.get_terminal_size().columns
        # Use the determined width, but ensure it is at least half the fallback
        # width to avoid overly short separators on very narrow terminals.
        width = max(columns, fallback_width // 2)  # Enforce a minimum reasonable width
except OSError:
# If output is redirected or terminal size unavailable, use fallback
width = fallback_width
return char * width
# Function to clear the current line before printing
def clear_line():
"""
Clear the current terminal line for clean console output using ANSI escape codes.
Notes
-----
Uses carriage return and the ANSI sequence `\033[2K` to erase the entire line.
This is more reliable than printing spaces across different terminal widths.
"""
sys.stdout.write("\r\033[2K") # Move to beginning, erase entire line
sys.stdout.flush()
# Hide/show cursor functions for progress displays
def hide_cursor():
"""
Hide the terminal cursor.
Notes
-----
Uses ANSI escape code to hide the cursor for cleaner progress bar display.
"""
sys.stdout.write(HIDE_CURSOR)
sys.stdout.flush()
def show_cursor():
"""
Show the terminal cursor.
Notes
-----
Uses ANSI escape code to restore cursor visibility after it was hidden.
"""
sys.stdout.write(SHOW_CURSOR)
sys.stdout.flush()
# Safety mechanism to restore cursor if process is terminated unexpectedly
def ensure_cursor_visible():
"""
Restore cursor visibility when program exits.
Notes
-----
Safety mechanism registered with atexit to ensure the terminal cursor
remains visible even if the program terminates unexpectedly.
"""
show_cursor()
# Register the safety function to run on exit
atexit.register(ensure_cursor_visible)
# Set up the logger
def setup_logging():
"""
Configure and initialize the logging system.
Returns
-------
logging.Logger
Configured logger instance for nsphere_plot.py
Notes
-----
Creates a log directory if it doesn't exist and configures
a file-based logger that writes to log/nsphere_plot.log with
timestamp, log level, and message formatting.
"""
# Create log directory if it doesn't exist
os.makedirs("log", exist_ok=True)
# Configure the logger
log_file = "log/nsphere_plot.log"
logging.basicConfig(
filename=log_file,
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
return logging.getLogger('nsphere_plot')
def log_message(message, level="info"):
"""
Log a message to the log file based on level and settings.
Parameters
----------
message : str
The message to log
level : str, optional
The log level (info, warning, error, debug), by default "info"
Notes
-----
Warning and error messages are always logged regardless of enable_logging flag.
Info and debug messages are only logged if enable_logging is True (--log specified).
"""
global enable_logging
# Always log warnings and errors regardless of enable_logging setting
if level.lower() == "warning":
logger.warning(message)
return
elif level.lower() == "error":
logger.error(message)
return
# For info and debug, only log if enable_logging is True
if not enable_logging:
return
if level.lower() == "info":
logger.info(message)
elif level.lower() == "debug":
logger.debug(message)
else:
# Default to info level
logger.info(message)
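# Illustrative calls (hypothetical messages):
#   log_message("Loaded 42 snapshot files")               # written only when --log is given
#   log_message("Missing psi profile", level="warning")   # warnings/errors are always logged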
# Initialize the logger
logger = setup_logging()
# Log plot saving to file, display on console cycling through filenames
def log_plot_saved(output_file, current=0, total=1, section_type="plots"):
"""
Log a saved plot to log file and update status on console.
Parameters
----------
output_file : str
Full path to the saved file
current : int, optional
Current plot number, by default 0
total : int, optional
Total plots to save, by default 1
section_type : str, optional
Type of plots being saved, by default "plots"
Notes
-----
Displays a progress bar and updates the console with the current plot
being saved. Includes timing information once enough plots have been
processed to calculate a reasonable rate.
"""
# Log the full path to the log file
logger.info(f"Plot saved: {output_file}")
# First plot initialization is handled globally
if current == 1:
pass
# Show shorter message on console with cycling display
prefix = get_file_prefix(output_file)
clear_line()
# Create progress display
progress = (current / total) * 100 if total > 0 else 100
bar_length = 20
filled_length = int(bar_length * current // total)
bar = '█' * filled_length + ' ' * (bar_length - filled_length)
# Calculate time estimates for multiple plots
time_info = ""
if hasattr(log_plot_saved, 'start_time') and total > 1:
elapsed = time.time() - log_plot_saved.start_time
# For display purposes, ensure elapsed time is at least 0.01 seconds
displayed_elapsed = max(elapsed, 0.01) # Minimum displayed time
# Apply a minimum nominal time to prevent unrealistically high rates
nominal_elapsed = max(elapsed, 0.02) # Minimum 0.02 seconds
if current > 1: # Need at least 2 plots to calculate rate
rate = current / nominal_elapsed # plots per second
# Cap the rate at a reasonable maximum for display
rate = min(rate, 99.9)
# Only show remaining time once there's a reasonable rate calculation
if current >= max(3, total * 0.1): # At least 3 plots or 10% complete
remaining = (total - current) / rate if rate > 0 else 0
time_info = f" [{displayed_elapsed:.2f}<{remaining:.2f}, {rate:.1f}file/s]"
else:
# Initialize the start time on first call
log_plot_saved.start_time = time.time()
# Print the progress bar and current plot name inline with extra padding
# Truncate the content part of the string
content_string = f"Save: {bar} {progress:.1f}% | File: {prefix}{time_info}"
sys.stdout.write(f"\r{truncate_and_pad_string(content_string)}")
sys.stdout.flush()
# Print newline if this is the last plot
if current == total:
sys.stdout.write("\n")
# Add a separator line after a completed progress bar
sys.stdout.write(get_separator_line(char='-') + "\n")
sys.stdout.flush()
# Add delay after completing progress section if enabled
if progress_delay > 0 and paced_mode:
show_progress_delay(progress_delay)
# Reset start time for next batch
if hasattr(log_plot_saved, 'start_time'):
delattr(log_plot_saved, 'start_time')
# Global state for combined progress tracking
_combined_plot_trackers = {}
def start_combined_progress(section_key, total_plots):
"""
Initialize a combined progress tracker for a section of related plots.
Parameters
----------
section_key : str
Unique identifier for this plotting section
total_plots : int
Total number of plots expected in this section
Notes
-----
Sets up a progress tracker in the global _combined_plot_trackers dictionary
to monitor progress across multiple related operations. Initializes timing
information for rate calculation.
"""
_combined_plot_trackers[section_key] = {
'current': 0,
'total': total_plots,
'plots': [],
'start_time': time.time() # Initialize timing information
}
logger.info(f"Starting combined progress tracking for {section_key} with {total_plots} total plots")
def update_combined_progress(section_key, output_file):
"""
Update the combined progress tracker for a specific plot section.
Parameters
----------
section_key : str
Identifier for the plotting section
output_file : str
Path to the plot file that was saved
Returns
-------
tuple
(current, total) counts for the section, indicating progress
Notes
-----
Updates the progress tracker for the specified section, increments
the counter, and displays a progress bar showing completion status.
Includes timing information and estimated completion time once
enough data points are available.
"""
if section_key not in _combined_plot_trackers:
logger.warning(f"No active progress tracker for {section_key}")
return (0, 0)
tracker = _combined_plot_trackers[section_key]
tracker['current'] += 1
tracker['plots'].append(output_file)
logger.info(f"Progress update: {output_file}")
prefix = get_file_prefix(output_file)
clear_line()
progress = (tracker['current'] / tracker['total']) * 100 if tracker['total'] > 0 else 100
bar_length = 20
filled_length = int(bar_length * tracker['current'] // tracker['total'])
bar = '█' * filled_length + ' ' * (bar_length - filled_length)
# Calculate time estimates
time_info = ""
if 'start_time' in tracker:
elapsed = time.time() - tracker['start_time']
# For display purposes, ensure elapsed time is at least 0.01 seconds
displayed_elapsed = max(elapsed, 0.01) # Minimum displayed time
# Apply a minimum nominal time to prevent unrealistically high rates
nominal_elapsed = max(elapsed, 0.02) # Minimum 0.02 seconds
# Always show timing information, even for single items
# Calculate a reasonable rate based on progress so far
rate = tracker['current'] / nominal_elapsed # items per second
# Cap the rate at a reasonable maximum for display
rate = min(rate, 99.9)
# Calculate remaining time based on current rate
remaining = (tracker['total'] - tracker['current']) / rate if rate > 0 else 0
# Always show timing information
time_info = f" [{displayed_elapsed:.2f}<{remaining:.2f}, {rate:.1f}file/s]"
else:
# Initialize the start time on first update
tracker['start_time'] = time.time()
# Determine if this is a loading progress or saving progress
action = "Loading data" if section_key.endswith("_data_loading") else "Saving plots"
# Print the progress bar and current plot name inline with extra padding
# For progress bars, use shorter Read/Save labels
display_action = "Read:" if section_key.endswith("_data_loading") else "Save:"
# Truncate the content part of the string
content_string = f"{display_action} {bar} {progress:.1f}% | File: {prefix}{time_info}"
sys.stdout.write(f"\r{truncate_and_pad_string(content_string)}")
sys.stdout.flush()
# Print newline if this is the last plot
if tracker['current'] >= tracker['total']:
sys.stdout.write("\n")
# Add a separator line after a completed progress bar
sys.stdout.write(get_separator_line(char='-') + "\n")
sys.stdout.flush()
# Add delay after completing progress section if enabled
if progress_delay > 0 and paced_mode:
show_progress_delay(progress_delay)
# Clear the tracker as it's complete
_combined_plot_trackers.pop(section_key, None)
return (tracker['current'], tracker['total'])
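# Illustrative pairing of the two trackers above (hypothetical section key and
# file names):
#   start_combined_progress("profile_plots", total_plots=2)
#   update_combined_progress("profile_plots", "results/density_profile.png")
#   update_combined_progress("profile_plots", "results/mass_enclosed.png")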
# Custom print functions for consistent output formatting
def print_status(message):
"""
Print a message to the console and also log it if detailed logging is enabled.
Parameters
----------
message : str
The message to print and potentially log.
"""
# Always print to console (if not showing help)
if not showing_help:
clear_line()
# Truncate the message before writing
sys.stdout.write(truncate_and_pad_string(message) + "\n")
# Also log the message if detailed logging is enabled
if enable_logging:
logger.info(f"[CONSOLE] {message}")
def read_lastparams(filename="data/lastparams.dat", return_suffix=True, user_suffix=None):
"""
Read parameters from lastparams.dat file.
Parameters
----------
filename : str, optional
Path to the lastparams.dat file. Default is "data/lastparams.dat".
return_suffix : bool, optional
If True, return a suffix string. If False, return individual parameters.
Default is True.
user_suffix : str, optional
If provided, look for lastparams_user_suffix.dat instead of lastparams.dat.
Returns
-------
str or tuple
If return_suffix is True, returns a string like "_[filetag]_npts_Ntimes_tfinal_factor"
or just "_npts_Ntimes_tfinal_factor" if no file tag is present.
If return_suffix is False, returns a tuple of (npts, Ntimes, tfinal_factor, file_tag).
"""
global showing_help
default_suffix = ""
default_params = (30000, 1001, 5, "") # Default values for (npts, Ntimes, tfinal_factor, file_tag)
# If showing help, return default values without file operations
if showing_help:
return default_suffix if return_suffix else default_params
# Always use data/lastparams.dat unless a specific suffix was provided
if user_suffix:
# If user supplied a suffix, use lastparams_suffix.dat
suffix_to_use = user_suffix.lstrip('_') # Remove leading underscore if present
filename = f"data/lastparams_{suffix_to_use}.dat"
logger.info(f"Using user-specified suffix: {suffix_to_use}, looking for file: {filename}")
else:
# Otherwise use lastparams.dat
filename = "data/lastparams.dat"
logger.info(f"Using default lastparams.dat file")
if not os.path.exists(filename):
print(truncate_and_pad_string(f"Error: Could not find {filename}, and no suffixed lastparams files were found."))
print(truncate_and_pad_string("Please run the simulation before plotting."))
print(truncate_and_pad_string("Alternatively, specify parameters using the --suffix command line argument."))
sys.exit(1)
try:
with open(filename, "r") as f:
line = f.readline().strip()
if not line:
return default_suffix if return_suffix else default_params
parts = line.split(maxsplit=3) # Split into at most 4 parts to handle file tag with spaces
if len(parts) < 3:
return default_suffix if return_suffix else default_params
npts = int(parts[0])
Ntimes = int(parts[1])
tfinal_factor = int(parts[2])
# Get file tag if available (might be empty)
file_tag = parts[3] if len(parts) > 3 else ""
# Log the file tag found in the file
logger.info(f"File tag from lastparams file: '{file_tag}'")
if return_suffix:
# Build suffix based on file tag
if file_tag:
return f"_{file_tag}_{npts}_{Ntimes}_{tfinal_factor}"
else:
return f"_{npts}_{Ntimes}_{tfinal_factor}"
else:
return (npts, Ntimes, tfinal_factor, file_tag)
except Exception as e:
print(truncate_and_pad_string(f"Warning: Error reading {filename}: {e}"))
return default_suffix if return_suffix else default_params
# Default values for module-level constants
# These will be properly initialized in main() via Configuration
npts, Ntimes, tfinal_factor, file_tag = 30000, 1001, 5, ""
# Define column counts for various data structures
ncol_traj_particles = 10
nlowest = 5
ncol_convergence = 2
ncol_debug_energy_compare = 8
ncol_particles_dat = 5
ncol_particlesfinal = 4
ncol_density_profile = 2
ncol_mass_profile = 2
ncol_psi_profile = 2
ncol_psi_theory = 3
ncol_dpsi_dr = 2
ncol_drho_dpsi = 2
ncol_f_of_E = 2
ncol_df_fixed_radius = 2
ncol_combined_histogram = 3
ncol_integrand = 2
ncol_particles_initial = 4
ncol_Rank_Mass_Rad_VRad_unsorted = 7
# Constants for unit conversions
kmsec_to_kpcmyr = 1.02271e-3 # Conversion factor from km/s to kpc/Myr
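# Example: a 220 km/s velocity is 220 * kmsec_to_kpcmyr ≈ 0.225 kpc/Myr.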
def filter_finite_rows(*arrs):
"""
Filter out rows containing NaN or Inf values across multiple arrays.
Parameters
----------
*arrs : array-like
Variable number of arrays to filter. All must have the same length.
Returns
-------
list
New arrays with invalid rows removed from all input arrays.
Notes
-----
This function removes any row that contains NaN or Inf in ANY of the input arrays.
All arrays must have the same length for proper row-wise filtering.
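    Examples
    --------
    >>> a = np.array([1.0, np.inf, 3.0])
    >>> b = np.array([4.0, 5.0, 6.0])
    >>> clean_a, clean_b = filter_finite_rows(a, b)
    >>> clean_a.tolist(), clean_b.tolist()
    ([1.0, 3.0], [4.0, 6.0])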
"""
if not arrs:
return arrs
length = len(arrs[0])
mask = np.ones(length, dtype=bool)
for arr in arrs:
if len(arr) != length:
raise ValueError(
"Arrays must have the same length to apply filter_finite_rows.")
mask &= np.isfinite(arr)
return [arr[mask] for arr in arrs]
# Additional column counts derived from the base counts
ncol_trajectories = 1 + 3 * ncol_traj_particles
ncol_single_trajectory = 4
ncol_energy_and_angular_momentum_vs_time = 1 + 4 * ncol_traj_particles
ncol_lowest_l_trajectories = 1 + 3 * nlowest
ncol_2d_hist_initial = 3
ncol_2d_hist_final = 3
ncol_all_particle_data_snapshot = 3
ncol_Rank_Mass_Rad_VRad_sorted = 8
# Parameter information is displayed in the main function header
def get_mem_usage_mb():
"""
Get current memory usage of the Python process.
Returns
-------
float
Current Resident Set Size (RSS) memory usage in megabytes.
Notes
-----
Provides the actual physical memory being used by the Python process
using the psutil library to access system resource information.
"""
proc = psutil.Process(os.getpid())
return proc.memory_info().rss / (1024.0 * 1024.0)
def update_timer_energy_plots(stop_timer, energy_plot_start_time, paced_mode):
"""
Continuously update the timer display for energy plot processing.
Parameters
----------
stop_timer : threading.Event
Event to signal when to stop the timer thread
energy_plot_start_time : float
Time when energy plot processing started
paced_mode : bool
Whether the visualization is in paced mode (with deliberate delays)
Notes
-----
Runs in a separate thread and updates the console with elapsed time,
estimated remaining time, and processing rate. Updates every 0.1 seconds
until stop_timer is set.
"""
# Only wait briefly to let the initial display settle
if paced_mode:
# In paced mode, use longer delay for better visualization
time.sleep(0.5)
else:
# In fast mode (default), use minimal delay
time.sleep(0.01)
while not stop_timer.is_set():
cur_elapsed = time.time() - energy_plot_start_time
# Apply a minimum nominal time to prevent unrealistically high rates
# For display purposes, ensure elapsed time is at least 0.01 seconds
displayed_elapsed = max(cur_elapsed, 0.01) # Minimum displayed time
nominal_elapsed = max(cur_elapsed, 0.02) # Minimum for rate calculation
est_remaining = nominal_elapsed # Estimate remaining time based on elapsed
cur_rate = 0.5 / nominal_elapsed # 0.5 plots in the elapsed time
# Cap the rate at a reasonable maximum for display
cur_rate = min(cur_rate, 99.9)
# Format the current timing info with extra padding spaces
time_info = f" [{displayed_elapsed:.2f}<{est_remaining:.2f}, {cur_rate:.1f}file/s]"
# Print the progress bar and current plot name inline with extra padding
bar_length = 20
half_filled = int(bar_length * 0.5)
half_bar = '█' * half_filled + ' ' * (bar_length - half_filled)
# Get current file prefix
unsorted_name = "Energy_vs_timestep_unsorted" # Base name without suffix
prefix = get_file_prefix(unsorted_name)
# Use ANSI escape sequences to update progress in-place
# Truncate the content part of the string
content_string = f"Save: {half_bar} 50.0% | File: {prefix}{time_info}"
sys.stdout.write(f"\r\033[2K{truncate_and_pad_string(content_string)}")
sys.stdout.flush()
# Sleep briefly to avoid high CPU usage
if paced_mode:
# In paced mode, use longer delay for better visualization
time.sleep(0.2)
else:
# In fast mode (default), use minimal delay
time.sleep(0.02)
def get_snapshot_number(filename, pattern=None):
"""
Extract the snapshot number from a filename containing a timestamp pattern.
Parameters
----------
filename : str
Filename to extract snapshot number from
pattern : re.Pattern, optional
Compiled regex pattern to use. If None, defaults to Rank_Mass_Rad_VRad_sorted_t pattern.
Returns
-------
int
Snapshot number extracted from filename, or a very large number if no match found
Notes
-----
Uses regular expression matching to extract the snapshot number from filenames
that follow a pattern like "Rank_Mass_Rad_VRad_sorted_t00012". Returns a large
value for non-matching files to ensure they sort after valid snapshot files.
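    Examples
    --------
    >>> get_snapshot_number("Rank_Mass_Rad_VRad_sorted_t00012.dat")
    12
    >>> get_snapshot_number("unrelated_file.dat")
    999999999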
"""
if pattern is None:
pattern = re.compile(r'Rank_Mass_Rad_VRad_sorted_t(\d+)')
match = pattern.search(filename)
if match:
return int(match.group(1))
return 999999999 # Default high value for non-matching files
def get_file_prefix(filepath):
"""
Extract the filename prefix without path, extension or suffix.
Parameters
----------
filepath : str
Full path or filename to process
Returns
-------
str
The extracted filename prefix
Examples
--------
>>> get_file_prefix('results/phase_space_initial_kris_40000_1001_5.png')
'phase_space_initial'
>>> get_file_prefix('loading_combined_histogram')
'combined_histogram'
Notes
-----
Handles special cases including loading prefixes and various suffix
patterns. For files with a standard suffix pattern (_tag_nnn_nnn_n),
removes these components to extract the core filename.
"""
# Make sure filepath is a string (handle None or other types)
if not isinstance(filepath, str):
return "unknown"
# Special handling for "loading_" identifiers
if filepath.startswith("loading_"):
return filepath[8:] # Remove "loading_" prefix
# Get the filename without the path
filename = os.path.basename(filepath)
# Remove the extension
prefix = os.path.splitext(filename)[0]
# For files with the suffix, remove the suffix part
parts = prefix.split('_')
# Handle potential energy_compare filename change
if prefix.startswith("energy_compare") and not prefix.startswith("debug_energy_compare"):
prefix = prefix.replace("energy_compare", "debug_energy_compare", 1)
# If the filename has a suffix pattern _tag_nnn_nnn_n
if len(parts) > 3 and all(part.isdigit() for part in parts[-3:]):
# Remove the last parts that match the suffix pattern
# For files with the pattern name_tag_npts_Ntimes_tfinal_factor
if len(parts) > 3:
# Keep everything before the suffix
# Re-check prefix after potential rename
prefix_parts = prefix.split('_')
return '_'.join(prefix_parts[:-4] if prefix_parts[-4].isalpha() else prefix_parts[:-3])
# If the file path looks like "loading_something_convergence_data" or other non-filename format
# (used for progress updates), just return the basename without special processing
if "loading_" in prefix or "processing_" in prefix:
return prefix
return prefix
def load_partial_file(filename, dtype=np.float32):
"""
Load data from file, handling incomplete or partially-written lines.
Parameters
----------
filename : str
Path to the file to read
dtype : numpy.dtype, optional
Data type to use for the array, by default np.float32
Returns
-------
numpy.ndarray or None
Array containing the data successfully read from the file, or None if
the file doesn't exist or contains no valid data
Notes
-----
The function stops reading if it encounters a line it cannot parse,
returning all data successfully read up to that point. This is useful
for handling files that might be incomplete or partially written.
Skips blank lines and stops reading at the first parsing error.
"""
# Skip if file doesn't exist
if not os.path.exists(filename):
print(truncate_and_pad_string(f"WARNING: {filename} not found. Skipping."))
return None
valid_rows = []
with open(filename, 'r') as f:
for line in f:
stripped = line.strip()
if not stripped:
# skip blank lines
continue
try:
row = np.fromstring(stripped, sep=' ', dtype=dtype)
except ValueError:
# If there's a parse error (e.g., incomplete line),
# Stop reading further
break
if len(row) == 0:
# If line was empty or couldn't parse anything, stop
continue
valid_rows.append(row)
return np.array(valid_rows, dtype=dtype)
def load_partial_file_bin(filename, ncols, dtype=np.float32):
"""
Load data from binary file with support for different data types.
Parameters
----------
filename : str
Path to the binary file to read
ncols : int
Number of columns expected in the data
dtype : numpy.dtype or list, optional
Data type(s) to use for the array, by default np.float32
Returns
-------
numpy.ndarray or None
Array containing the data from the file, or None if the file
is missing or contains no valid data
Notes
-----
Supports two different modes of operation:
1. Single dtype (e.g., np.float32):
Returns array of shape (N, ncols)
2. List/tuple of dtypes:
Returns structured array with fields named 'f0', 'f1', ..., 'f{ncols-1}',
each with its corresponding dtype from the list
Uses align=False to ensure tight packing of structured arrays.
Reads as many complete rows as possible from the file.
"""
if not os.path.exists(filename):
print(truncate_and_pad_string(f"WARNING: {filename} not found. Skipping."))
return None
# Get file size for diagnostics
file_size = os.path.getsize(filename)
# Log file size for very small files (<1000 bytes) to help debug issues
if file_size < 1000:
logger.info(f"Small binary file detected: {filename}, size: {file_size} bytes")
with open(filename, 'rb') as f:
raw = f.read()
if not raw:
logger.warning(f"Empty binary file: {filename}, size: {file_size} bytes")
return None
# Check if dtype is a list of dtypes => build structured dtype
if isinstance(dtype, (list, tuple)):
if len(dtype) != ncols:
logger.error(f"ERROR: dtype list length {len(dtype)} != ncols={ncols}")
return None
# Build fields: [("f0", dtype[0]), ("f1", dtype[1]), ...]
fields = []
for i in range(ncols):
fields.append((f"f{i}", dtype[i]))
structured = np.dtype(fields, align=False)
itemsize = structured.itemsize
nrows = len(raw) // itemsize
# Log extra diagnostics for small files
if nrows < 100:
logger.info(f"Binary file has few rows: {filename}, raw bytes: {len(raw)}, itemsize: {itemsize}, estimated rows: {nrows}")
if nrows == 0:
logger.warning(f"No complete rows found in binary file: {filename}, raw bytes: {len(raw)}, itemsize: {itemsize}")
# For very small files, log hex dump of raw bytes to help debug
if len(raw) < 100:
logger.info(f"Raw binary data (hex): {' '.join(f'{b:02x}' for b in raw[:100])}")
return None
# Drop leftover bytes
bytes_to_use = nrows * itemsize
if len(raw) != bytes_to_use:
logger.warning(f"Partial final row detected: {filename}, using {bytes_to_use} of {len(raw)} bytes")
raw = raw[:bytes_to_use]
arr = np.frombuffer(raw, dtype=structured)
gc.collect()
# For debugging small datasets, log some content
if nrows < 10:
logger.info(f"Structured array content (first {min(nrows, 5)} rows): {arr[:5] if nrows > 0 else 'empty'}")
return arr
else:
# Normal single-dtype path
arr = np.frombuffer(raw, dtype=dtype)
full_count = arr.size // ncols
# Log extra diagnostics for small arrays
if full_count < 100:
logger.info(f"Binary file has few rows: {filename}, array size: {arr.size}, columns: {ncols}, complete rows: {full_count}")
if full_count == 0:
logger.warning(f"No complete rows found in binary file: {filename}, array size: {arr.size}, columns: {ncols}")
# For very small arrays, log raw values to help debug
if arr.size < 100:
logger.info(f"Raw array values: {arr}")
return None
arr = arr[:full_count * ncols]
arr = arr.reshape((full_count, ncols))
gc.collect()
# For debugging small datasets, log some content
if full_count < 10:
logger.info(f"Array content (first {min(full_count, 5)} rows): {arr[:5] if full_count > 0 else 'empty'}")
return arr
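# Illustrative calls (hypothetical filenames):
#   arr = load_partial_file_bin("data/mass_profile.dat", ncols=2)
#   # -> (N, 2) float32 array, or None if the file is missing/empty
#   recs = load_partial_file_bin("data/ranks.dat", ncols=3,
#                                dtype=[np.int32, np.float32, np.float32])
#   # -> structured array with fields 'f0', 'f1', 'f2'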
def safe_load_and_filter(filename, dtype=np.float32):
"""
Load and filter data from file with error handling.
Parameters
----------
filename : str
Path to the file to read
dtype : numpy.dtype, optional
Data type to use for the array, by default np.float32
Returns
-------
numpy.ndarray or None
Filtered array containing only valid data, or None if
the file doesn't exist or contains no valid data
Notes
-----
This function performs three steps:
1. Verifies the file exists
2. Loads as much data as possible using load_partial_file
3. Filters out rows containing NaN or Inf values
Returns None with appropriate warnings if the file is missing or
contains no valid data after filtering.
"""
if not os.path.exists(filename):
print(truncate_and_pad_string(f"WARNING: {filename} does not exist. Skipping."))
return None
data = load_partial_file(filename, dtype=dtype)
gc.collect()
if data is None or data.size == 0:
# Means either the file was missing or empty or partial lines only
print(truncate_and_pad_string(f"WARNING: {filename} had no valid data. Skipping."))
return None
# Filter out any invalid (NaN/Inf) rows
mask = np.all(np.isfinite(data), axis=1)
data = data[mask]
if data.size == 0:
print(
truncate_and_pad_string(f"WARNING: After filtering, {filename} had no valid rows. Skipping."))
return None
return data
def safe_load_and_filter_bin(filename, ncols, dtype=np.float32):
"""
Load and filter binary data with support for structured arrays.
Parameters
----------
filename : str
Path to the binary file to read
ncols : int
Number of columns expected in the data
dtype : numpy.dtype or list, optional
Data type(s) to use for the array, by default np.float32
Returns
-------
numpy.ndarray or None
2D float array containing only valid data, or None if
the file doesn't exist or contains no valid data
Notes
-----
This function performs four steps:
1. Verifies the file exists
2. Loads binary data using load_partial_file_bin
3. Converts structured arrays to standard floating-point arrays if needed
4. Filters out rows containing NaN or Inf values
For structured arrays (when dtype is a list), each field is extracted and
converted to float32 before filtering. This ensures consistent handling
regardless of the original data types.
"""
if not os.path.exists(filename):
print(truncate_and_pad_string(f"WARNING: {filename} does not exist. Skipping."))
return None
data = load_partial_file_bin(filename, ncols=ncols, dtype=dtype)
gc.collect()
if data is None or data.size == 0:
print(truncate_and_pad_string(f"WARNING: {filename} had no valid data. Skipping."))
return None
# Handle structured arrays (from list dtype) differently
if isinstance(dtype, (list, tuple)):
# Convert structured array to a regular 2D array with all fields as float32
nrows = data.shape[0]
float_data = np.zeros((nrows, ncols), dtype=np.float32)
# Extract each field and convert to float
for i in range(ncols):
field_name = f"f{i}"
float_data[:, i] = data[field_name].astype(np.float32, copy=False)
# Replace structured array with regular float array
data = float_data
# Apply isfinite to the standard array
mask = np.all(np.isfinite(data), axis=1)
data = data[mask]
else:
# For regular arrays, just filter normally
mask = np.all(np.isfinite(data), axis=1)
data = data[mask]
if data.size == 0:
print(
truncate_and_pad_string(f"WARNING: After filtering, {filename} had no valid rows. Skipping."))
return None
return data
def load_specific_columns_bin(filename, ncols_total, cols_to_load, dtype_list):
"""
Loads specific columns from a flat binary file using memory mapping.
This version uses numpy.memmap for efficiency with large files, reading
only necessary data pages from disk into memory when columns are accessed.
Parameters
----------
filename : str
Path to the binary file.
ncols_total : int
Total number of columns (fields) in each row of the file.
cols_to_load : list[int]
List of 0-based column indices to load and return.
dtype_list : list
List of numpy dtypes corresponding to *all* columns in the file,
in order. E.g., [np.int32, np.float32, np.float32, ...].
Returns
-------
list[np.ndarray] or None
A list containing 1D numpy arrays (as float32) for each requested
column, filtered for finite values based on the loaded columns.
Returns None if the file doesn't exist, cannot be mapped,
or contains no valid data in requested columns.
Notes
-----
Requires the `dtype_list` specifying the type for *all* columns
to correctly interpret the binary structure for memmap.
Returned arrays are converted to float32 for consistency.
"""
if not os.path.exists(filename):
logger.warning(f"Memmap loader: File not found {filename}")
return None
if not isinstance(dtype_list, list) or len(dtype_list) != ncols_total:
logger.error(f"Memmap loader: Invalid dtype_list provided for {filename}. Expected list of length {ncols_total}.")
return None
try:
# Define the structured dtype for the entire row
row_fields = [(f'f{i}', dtype) for i, dtype in enumerate(dtype_list)]
row_dtype = np.dtype(row_fields)
item_size = row_dtype.itemsize
file_size = os.path.getsize(filename)
if file_size < item_size:
logger.warning(f"Memmap loader: File {filename} is smaller than one row.")
return None
# Calculate number of rows based on file size and item size
num_rows = file_size // item_size
if num_rows == 0:
logger.warning(f"Memmap loader: No complete rows found in {filename}")
return None
# Memory-map the file
# mode='r' is read-only
mmap_array = np.memmap(filename, dtype=row_dtype, mode='r', shape=(num_rows,))
# Extract required columns using field names
# Create copies to bring data into memory and ensure float32
extracted_cols = []
col_names_to_load = [f'f{i}' for i in cols_to_load]
for col_name in col_names_to_load:
# Access column via field name, convert to float32 and copy into RAM
# Using copy() is crucial when the memmap will be closed.
extracted_cols.append(mmap_array[col_name].astype(np.float32).copy())
# Filter based on finiteness of all loaded columns *together*
# Stack columns temporarily for efficient masking
valid_cols_data = np.column_stack(extracted_cols)
finite_mask = np.all(np.isfinite(valid_cols_data), axis=1)
del valid_cols_data # Free temporary stack
gc.collect()
if not np.any(finite_mask):
logger.warning(f"Memmap loader: No finite rows for requested columns in {filename}")
# Clean up memmap object before returning
del mmap_array
gc.collect()
return None
# Apply mask to the list of extracted columns
filtered_cols = [col[finite_mask] for col in extracted_cols]
# Clean up memmap object explicitly (important!)
del mmap_array
gc.collect()
return filtered_cols
except Exception as e:
logger.error(f"Memmap loader: Error processing {filename}: {e}")
logger.error(traceback.format_exc())
# Ensure memmap is cleaned up if partially created
if 'mmap_array' in locals() and mmap_array._mmap is not None:
del mmap_array
gc.collect()
return None
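# Illustrative call (hypothetical file layout: one int32 id column followed by
# six float32 columns); returns the requested columns as float32 arrays:
#   cols = load_specific_columns_bin("data/Rank_Mass_Rad_VRad_unsorted.dat",
#                                    ncols_total=7, cols_to_load=[2, 3],
#                                    dtype_list=[np.int32] + [np.float32] * 6)
#   if cols is not None:
#       col_a, col_b = cols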
def load_partial_file_10(filename, dtype=np.float32):
"""
Load the first 10 lines from a text file.
Parameters
----------
filename : str
Path to the file to read
dtype : numpy.dtype, optional
Data type to use for the array, by default np.float32
Returns
-------
numpy.ndarray or None
Array containing the data from the first 10 lines of the file,
or None if the file doesn't exist or contains no valid data
Notes
-----
Useful for quickly previewing or analyzing the beginning of a large file
without loading the entire dataset into memory. Only reads the first 10
lines that contain valid data.
"""
if not os.path.exists(filename):
return None
with open(filename, 'r') as f:
lines = []
for i, line in enumerate(f):
if i >= 10: # Only read 10 lines
break
lines.append(line)
if not lines:
return None
# Convert the lines to a numpy array
data = np.array([np.fromstring(line.strip(), sep=' ', dtype=dtype) for line in lines if line.strip()])
return data
def safe_load_and_filter_10(filename, dtype=np.float32):
"""
Load and filter the first 10 lines from a text file.
Parameters
----------
filename : str
Path to the file to read
dtype : numpy.dtype, optional
Data type to use for the array, by default np.float32
Returns
-------
numpy.ndarray or None
Filtered array containing only valid data from the first 10 lines,
or None if the file doesn't exist or contains no valid data
Notes
-----
Similar to safe_load_and_filter but only reads the first 10 lines.
Useful for analyzing the beginning of large files efficiently.
"""
if not os.path.exists(filename):
print(truncate_and_pad_string(f"WARNING: {filename} does not exist. Skipping."))
return None
data = load_partial_file_10(filename, dtype=dtype)
gc.collect()
if data is None or data.size == 0:
print(truncate_and_pad_string(f"WARNING: {filename} had no valid data. Skipping."))
return None
# Filter out any invalid (NaN/Inf) rows
mask = np.all(np.isfinite(data), axis=1)
data = data[mask]
if data.size == 0:
print(
truncate_and_pad_string(f"WARNING: After filtering, {filename} had no valid rows. Skipping."))
return None
return data
def safe_load_and_filter_10_bin(filename, ncols, dtype=np.float32):
"""
Load and filter the first 10 rows from a binary file.
Parameters
----------
filename : str
Path to the binary file to read
ncols : int
Number of columns expected in the data
dtype : numpy.dtype or list, optional
Data type(s) to use for the array, by default np.float32
Returns
-------
numpy.ndarray or None
Filtered array containing only valid data from the first 10 rows,
or None if the file doesn't exist or contains no valid data
Notes
-----
Performs the following steps:
1. Checks if file exists
2. Loads entire binary file via load_partial_file_bin
3. Slices off the first 10 rows (if available)
4. Filters out invalid rows (NaN/Inf)
5. Returns the filtered array, or None if empty/no valid rows
Useful for analyzing the beginning of large binary files efficiently.
"""
if not os.path.exists(filename):
print(truncate_and_pad_string(f"WARNING: {filename} does not exist. Skipping."))
return None
data = load_partial_file_bin(filename, ncols=ncols, dtype=dtype)
gc.collect()
if data is None or data.size == 0:
print(
truncate_and_pad_string(f"WARNING: {filename} had no valid data (or is empty). Skipping."))
return None
# Keep only the first 10 rows (or fewer if file has <10 rows)
max_needed = min(10, data.shape[0])
data = data[:max_needed, :]
# Filter out any invalid (NaN/Inf) rows
mask = np.all(np.isfinite(data), axis=1)
data = data[mask]
if data.size == 0:
print(
truncate_and_pad_string(f"WARNING: After filtering, {filename} had no valid rows. Skipping."))
return None
return data
def safe_load_particle_ids_bin(filename, ncols, particle_ids, dtype=np.float32):
"""
Load data for specific particle IDs from a flat binary file.
Uses efficient file seeking for basic numpy dtypes (e.g., np.float32)
to read only requested rows. Falls back to loading the full file via
`load_partial_file_bin` for complex/structured dtypes.
Parameters
----------
filename : str
Path to the binary file (assumed flat row data).
ncols : int
Number of columns per row.
particle_ids : array-like
Array of particle IDs (0-based row indices) to extract.
dtype : numpy.dtype or list, optional
Data type of file contents. Basic numpy dtypes trigger optimized
seek-based reading. Lists/tuples trigger a full read and fallback
extraction. Default is np.float32.
Returns
-------
numpy.ndarray or None
Numpy array (float32) containing only the rows for the specified
particle IDs (NaNs for missing/invalid IDs), or None if file
doesn't exist or a fatal error occurs. Returns float32 array
even if input dtype has different precision.
"""
# Basic validation
if not os.path.exists(filename):
return None
# Check if we should use the optimized path
use_optimized_path = isinstance(dtype, type) and hasattr(dtype, 'itemsize')
if not use_optimized_path:
data = load_partial_file_bin(filename, ncols=ncols, dtype=dtype) # Existing full load
if data is None:
return None
# Convert structured to float32 if needed (logic from safe_load_and_filter_bin)
if isinstance(dtype, (list, tuple)):
nrows = data.shape[0]
float_data = np.zeros((nrows, ncols), dtype=np.float32)
for i in range(ncols):
field_name = f"f{i}"
float_data[:, i] = data[field_name].astype(np.float32, copy=False)
data = float_data
# Now extract rows using boolean indexing (less efficient than direct seek)
num_particles = len(particle_ids)
result_array = np.full((num_particles, ncols), np.nan, dtype=np.float32)
# Need unique IDs and sort order for efficient lookup if using fallback
unique_ids, original_indices = np.unique(particle_ids, return_inverse=True)
valid_mask = (unique_ids >= 0) & (unique_ids < data.shape[0])
found_ids = unique_ids[valid_mask]
if len(found_ids) > 0:
extracted_data = data[found_ids, :]
# Map back to original particle_ids order
mapping = {pid: i for i, pid in enumerate(found_ids)}
original_mask_for_found = valid_mask[original_indices]
result_indices = np.where(original_mask_for_found)[0]
source_indices = [mapping[pid] for pid in np.array(particle_ids)[original_mask_for_found]]
if len(result_indices) > 0:
result_array[result_indices] = extracted_data[source_indices]
return result_array
# --- Start: Optimized path for basic dtypes like float32 ---
try:
item_size = np.dtype(dtype).itemsize
row_size = ncols * item_size
file_size = os.path.getsize(filename)
max_rows = file_size // row_size
if row_size == 0 or max_rows == 0:
return None
num_requested = len(particle_ids)
# Ensure result array is float32, regardless of input dtype's precision
result_array = np.full((num_requested, ncols), np.nan, dtype=np.float32)
found_count = 0
with open(filename, 'rb') as f:
for i, pid in enumerate(particle_ids):
# Check if pid is a valid row index for this file
if 0 <= pid < max_rows:
offset = pid * row_size
f.seek(offset)
row_bytes = f.read(row_size)
if len(row_bytes) == row_size:
# Convert bytes to numpy array of the specified dtype
row_data = np.frombuffer(row_bytes, dtype=dtype, count=ncols)
# Store as float32
result_array[i, :] = row_data.astype(np.float32, copy=False)
found_count += 1
if found_count == 0:
return None
return result_array
    except Exception as e:
        logger.warning(f"safe_load_particle_ids_bin: error reading {filename}: {e}")
        return None
finally:
gc.collect()
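# Illustrative call (hypothetical filename): fetch rows for particles 0, 5 and
# 42 from a 5-column float32 file; out-of-range IDs come back as NaN rows:
#   rows = safe_load_particle_ids_bin("data/particles.dat", ncols=5,
#                                     particle_ids=[0, 5, 42])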
def plot_density(r_values, rho_values, output_file="density_profile.png"):
"""
Plot density profile as a function of radius and save to file.
Parameters
----------
r_values : array-like
Radius values in kpc
rho_values : array-like
Density values (rho) at each radius
output_file : str, optional
Path to save the output plot, by default "density_profile.png"
Returns
-------
str
Path to the saved plot file
Notes
-----
Filters out any non-finite values before plotting.
Creates a figure showing density as a function of radius with
appropriate labels and grid.
"""
r_values, rho_values = filter_finite_rows(r_values, rho_values)
plt.figure(figsize=(10, 6))
plt.plot(r_values, rho_values, linewidth=2, label=r'$\rho(r)$')
plt.xlabel(r'$r$ (kpc)', fontsize=12)
plt.ylabel(r'$\rho(r)$ (M$_{\odot}$/kpc$^3$)', fontsize=12)
plt.title(r'Radial Density Profile $\rho(r)$', fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig(output_file, dpi=200)
plt.close()
log_plot_saved(output_file)
return output_file
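# Hedged usage sketch for the profile plotters (synthetic values, hypothetical
# output path; the real pipeline feeds data loaded from the binary profile files):
#     r = np.logspace(-1, 2, 200)            # radii in kpc
#     rho = 1.0e7 / (r * (1.0 + r) ** 2)     # toy NFW-like density (Msun/kpc^3)
#     plot_density(r, rho, output_file="results/density_example.png")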
[docs]
def plot_mass_enclosed(r_values, mass_values, output_file="mass_enclosed.png"):
"""
Plot enclosed mass as a function of radius and save to file.
Parameters
----------
r_values : array-like
Radius values in kpc
mass_values : array-like
Enclosed mass values at each radius in Msun
output_file : str, optional
Path to save the output plot, by default "mass_enclosed.png"
Returns
-------
str
Path to the saved plot file
Notes
-----
Filters out any non-finite values before plotting. Creates a figure
showing the mass enclosed within each radius with appropriate labels and grid.
"""
r_values, mass_values = filter_finite_rows(r_values, mass_values)
plt.figure(figsize=(10, 6))
plt.plot(r_values, mass_values, label=r'$M(r)$')
plt.xlabel(r'$r$ (kpc)', fontsize=12)
plt.ylabel(r'$M(r)$ (M$_\odot$)', fontsize=12)
plt.title(r'Enclosed Mass Profile $M(r)$', fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig(output_file, dpi=150)
plt.close()
log_plot_saved(output_file)
return output_file
[docs]
def plot_psi(r_values, psi_values, output_file="psi_profile.png"):
"""
Plot gravitational potential (Psi) as a function of radius and save to file.
Parameters
----------
r_values : array-like
Radius values in kpc
psi_values : array-like
Gravitational potential values at each radius (in km^2/s^2)
output_file : str, optional
Path to save the output plot, by default "psi_profile.png"
Returns
-------
None
Function does not return a value, but saves plot to the specified path
Notes
-----
Filters out any non-finite values before plotting.
Creates a figure showing gravitational potential as a function of radius
with appropriate labels and grid. Assumes psi_values are already scaled.
"""
r_values, psi_values = filter_finite_rows(r_values, psi_values)
plt.figure(figsize=(10, 6))
plt.plot(r_values, psi_values, label=r'$\Psi(r)$')
plt.xlabel(r'$r$ (kpc)', fontsize=12)
plt.ylabel(r'$\Psi(r)$ (km$^2$/s$^2$)', fontsize=12)
plt.title(r'Gravitational Potential Profile $\Psi(r)$', fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig(output_file, dpi=150)
plt.close()
[docs]
def plot_dpsi_dr(r_values, dpsi_values, output_file="dpsi_dr.png"):
"""
Plot gravitational acceleration (dPsi/dr) as a function of radius and save to file.
Parameters
----------
r_values : array-like
Radius values in kpc
dpsi_values : array-like
Gravitational acceleration values at each radius
output_file : str, optional
Path to save the output plot, by default "dpsi_dr.png"
Returns
-------
None
Function does not return a value, but saves plot to the specified path
Notes
-----
Filters out any non-finite values before plotting.
Creates a figure showing gravitational acceleration (dPsi/dr) as a
function of radius with appropriate labels and grid.
Units of dPsi/dr are typically (km/s)^2 / kpc.
"""
r_values, dpsi_values = filter_finite_rows(r_values, dpsi_values)
plt.figure(figsize=(10, 6))
plt.plot(r_values, dpsi_values, label=r'$d\Psi/dr$')
plt.xlabel(r'$r$ (kpc)', fontsize=12)
plt.ylabel(r'$d\Psi/dr$ ((km/s)$^2$/kpc)', fontsize=12)
plt.title(r'Gravitational Acceleration Profile $d\Psi/dr$', fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig(output_file, dpi=150)
plt.close()
[docs]
def plot_drho_dpsi(psi_values, drho_dpsi, output_file="drho_dpsi.png"):
"""
Plot derivative of density with respect to potential and save to file.
Parameters
----------
psi_values : array-like
Gravitational potential values (in km^2/s^2)
drho_dpsi : array-like
Derivative of density with respect to potential
output_file : str, optional
Path to save the output plot, by default "drho_dpsi.png"
Returns
-------
None
Function does not return a value, but saves plot to the specified path
Notes
-----
Filters out any non-finite values before plotting.
Creates a figure showing the derivative of density with respect to
gravitational potential with appropriate labels and grid.
"""
psi_values, drho_dpsi = filter_finite_rows(psi_values, drho_dpsi)
plt.figure(figsize=(10, 6))
plt.plot(psi_values, drho_dpsi, label=r'$d\rho/d\Psi$')
plt.xlabel(r'$\Psi$ (km$^2$/s$^2$)', fontsize=12)
plt.ylabel(r'$d\rho/d\Psi$ ((M$_\odot$/kpc$^3$)/(km$^2$/s$^2$))', fontsize=12)
plt.title(r'Density Derivative $d\rho/d\Psi$', fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig(output_file, dpi=150)
plt.close()
[docs]
def plot_f_of_E(E_values, f_values, output_file="f_of_E.png"):
"""
Plot distribution function as a function of energy and save to file.
Parameters
----------
E_values : array-like
Energy values
f_values : array-like
Distribution function values at each energy
output_file : str, optional
Path to save the output plot, by default "f_of_E.png"
Returns
-------
str
Path to the saved plot file
Notes
-----
Filters out any non-finite values before plotting.
Creates a figure showing the distribution function (f(E)) as a
function of energy with appropriate labels and grid.
"""
E_values, f_values = filter_finite_rows(E_values, f_values)
plt.figure(figsize=(10, 6))
plt.plot(E_values, f_values, linewidth=2, label=r'$f(\mathcal{E})$')
plt.xlabel(r'$\mathcal{E}$ (km$^2$/s$^2$)', fontsize=12)
plt.ylabel(r'$f(\mathcal{E})$ ((km/s)$^{-3}$ kpc$^{-3}$)', fontsize=12)
plt.title(r'Distribution Function $f(\mathcal{E})$', fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig(output_file, dpi=200)
plt.close()
log_plot_saved(output_file)
return output_file
[docs]
def plot_df_at_fixed_radius(v_values, df_values, r_fixed, output_file="df_fixed_radius.png"):
"""
Plot the distribution function at a fixed radius.
Parameters
----------
v_values : array-like
Velocity values
df_values : array-like
Distribution function values
r_fixed : float
The fixed radius value in kpc
output_file : str, optional
Path to save the output plot, by default "df_fixed_radius.png"
Returns
-------
str
Path to the saved plot file
Notes
-----
Filters out any non-finite values before plotting.
Creates a figure showing the distribution function at a specific
fixed radius as a function of velocity.
"""
# Only log warnings for critical issues
if np.allclose(df_values, 0.0):
logger.warning("df_fixed_radius: All distribution function values are near zero")
# Only if all values are exactly 0, set a small value to prevent empty plot
if np.max(df_values) == 0:
df_values = np.maximum(df_values, 1e-10)
# Filter out non-finite values
v_values, df_values = filter_finite_rows(v_values, df_values)
# Check if arrays are empty after filtering
if len(v_values) == 0 or len(df_values) == 0:
logger.error("df_fixed_radius: No valid data points after filtering")
return None
# Use a simpler single-panel plot
plt.figure(figsize=(10, 6))
plt.plot(v_values, df_values, linewidth=2, label=r'$f$ at $r_{\mathrm{fixed}} = ' + f'{r_fixed:.1f}' + r'$ kpc')
plt.xlabel(r'$v$ (km/s)', fontsize=12)
plt.ylabel(r'$f(\mathcal{E},r)$ ((km/s)$^{-3}$ kpc$^{-3}$)', fontsize=12)
plt.title(r'Distribution Function at Fixed Radius ($r_{\mathrm{fixed}} = ' + f'{r_fixed:.1f}' + r'$ kpc)', fontsize=14)
plt.grid(True, linestyle='--', alpha=0.7)
plt.legend(fontsize=12)
plt.tight_layout()
plt.savefig(output_file, dpi=200)
plt.close()
log_plot_saved(output_file)
return output_file
[docs]
def plot_combined_histogram_from_file(input_file, output_file):
"""
Create an overlay histogram comparing initial and final radial distributions.
Parameters
----------
input_file : str
Path to the input data file containing combined histogram data
output_file : str
Path to save the output plot
Returns
-------
str
Path to the saved plot file
Notes
-----
The input file should contain three columns:
- bin_centers: The center values of histogram bins
- initial counts: Counts in each bin for the initial distribution
- final counts: Counts in each bin for the final distribution
The plot shows three categories using different colors:
- Overlap between initial and final (purple)
- Initial > Final (blue)
- Final > Initial (red)
This visualization helps identify which parts of the distribution
remained stable and which parts changed during the simulation.
"""
data = safe_load_and_filter_bin(input_file, ncol_combined_histogram, dtype=[
np.float32, np.int32, np.int32])
if data is None:
return None
gc.collect()
bin_centers = data[:, 0]
hist_iradii = data[:, 1]
hist_fradii = data[:, 2]
overlap = np.minimum(hist_iradii, hist_fradii)
initial_excess = hist_iradii - overlap
final_excess = hist_fradii - overlap
bin_width = bin_centers[1]-bin_centers[0] if len(bin_centers) > 1 else 1.0
plt.figure(figsize=(10, 6))
plt.bar(bin_centers, overlap, width=bin_width,
color='purple', label='Overlap', align='center')
plt.bar(bin_centers, initial_excess, width=bin_width, bottom=overlap,
color='blue', label='Initial > Final', align='center')
plt.bar(bin_centers, final_excess, width=bin_width, bottom=overlap,
color='red', label='Final > Initial', align='center')
plt.xlabel(r'$r$ (kpc)', fontsize=12)
plt.ylabel(r'$N(r)$', fontsize=12)
plt.title(r'Comparison of Initial and Final Radial Distributions $N(r)$', fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig(output_file, dpi=200)
plt.close()
log_plot_saved(output_file)
return output_file
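# Hedged sketch of the overlap/excess decomposition used in the stacked bars
# above, written as a stand-alone helper on arbitrary count arrays. It is not
# called by the pipeline; it only restates the np.minimum bookkeeping.
def _example_histogram_overlap(hist_initial, hist_final):
    """Split two count arrays into (overlap, initial_excess, final_excess)."""
    hist_initial = np.asarray(hist_initial)
    hist_final = np.asarray(hist_final)
    overlap = np.minimum(hist_initial, hist_final)
    return overlap, hist_initial - overlap, hist_final - overlap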
[docs]
def plot_trajectories(input_file, output_file):
"""
Plot particle trajectories (radius vs time) for multiple particles.
Parameters
----------
input_file : str
Path to the input data file containing trajectory data.
output_file : str
Path to save the output plot.
Returns
-------
str
Path to the saved plot file
Notes
-----
The input file should contain time in the first column, followed by a group of
three columns per particle; the first column of each group holds the radial
position plotted here.
"""
data = safe_load_and_filter_bin(
input_file, ncol_trajectories, dtype=np.float32)
if data is None:
return None
gc.collect()
time = data[:, 0]
plt.figure(figsize=(10, 6))
ncols = data.shape[1]
nparticles = (ncols - 1)//3
for p in range(nparticles):
r_col = 1 + 3*p
plt.plot(time, data[:, r_col], linewidth=1.5)
plt.xlabel(r'$t$ (Myr)', fontsize=12)
plt.ylabel(r'$r(t)$ (kpc)', fontsize=12)
plt.title(r'Particle Trajectories', fontsize=14)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig(output_file, dpi=200)
plt.close()
log_plot_saved(output_file)
return output_file
[docs]
def plot_single_trajectory(input_file, output_file):
"""
Plot the orbital trajectory (radius vs time) for a single test particle.
Parameters
----------
input_file : str
Path to the input data file containing trajectory data.
output_file : str
Path to save the output plot.
Returns
-------
str
Path to the saved plot file
Notes
-----
The input file should contain columns with time and radial position.
"""
data = safe_load_and_filter_bin(
input_file, ncol_single_trajectory, dtype=np.float32)
if data is None:
return None
gc.collect()
time = data[:, 0]
radius = data[:, 1]
plt.figure(figsize=(10, 6))
plt.plot(time, radius, linewidth=2)
plt.xlabel(r'$t$ (Myr)', fontsize=12)
plt.ylabel(r'$r(t)$ (kpc)', fontsize=12)
plt.title(r'Single Particle Trajectory', fontsize=14)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig(output_file, dpi=200)
plt.close()
log_plot_saved(output_file)
return output_file
[docs]
def plot_energy_time(input_file, output_file):
"""
Plot energy vs. time for multiple particles, illustrating integration stability.
NOTE: 'Current' energy uses the initial static potential for evaluation.
Parameters
----------
input_file : str
Path to the input data file containing energy data.
output_file : str
Path to save the output plot.
Returns
-------
str
Path to the saved plot file
Notes
-----
The input file should contain time in the first column, followed by a group of
four values per particle (current energy evaluated with the static potential,
initial energy, current angular momentum, initial angular momentum); only the
two energy columns are plotted here. Loads the full file into memory.
"""
# Load full data using the binary safe loader
data = safe_load_and_filter_bin(
input_file, ncol_energy_and_angular_momentum_vs_time, dtype=np.float32)
if data is None:
# Log warning if data loading failed
logger.warning(f"Failed to load or filter data from {input_file} for energy plot.")
return None
gc.collect()
time = data[:, 0]
ncols = data.shape[1]
nparticles = (ncols - 1) // 4 # Calculate number of particles based on columns
# Check if there are enough columns for at least one particle
if nparticles <= 0:
logger.warning(f"No particle data found in {input_file} after calculating columns.")
return None
fig, ax = plt.subplots(figsize=(10, 6)) # Get axis handle
for p in range(nparticles):
E_col = 1 + 4*p # Column index for current energy (using static Psi)
E_i_col = 2 + 4*p # Column index for initial energy
# Ensure column indices are valid
if E_i_col >= ncols:
logger.warning(f"Column index {E_i_col} out of bounds for particle {p+1} in {input_file}.")
continue # Skip this particle if indices are invalid
# Plot current energy (using static Psi) and initial energy
# Updated label to clarify E(t) uses static potential
ax.plot(time, data[:, E_col], linewidth=1.5, label=fr'$\mathcal{{E}}(t)$ (Static $\Psi_0$) P{p+1}')
ax.plot(time, data[:, E_i_col], '--', linewidth=1.5, label=fr'$\mathcal{{E}}_0$ P{p+1}')
# Use LaTeX for axis labels
ax.set_xlabel(r'$t$ (Myr)', fontsize=12)
ax.set_ylabel(r'$\mathcal{E}$ (km$^2$/s$^2$)', fontsize=12)
ax.set_title(r'Integration Stability: Energy Conservation Test', fontsize=14)
# Add legend, adjust size and location as needed
# Reduced font size slightly to avoid potential overlap with text box
ax.legend(fontsize=9, loc='upper right')
ax.grid(True, linestyle='--', alpha=0.7)
# Add explanatory text
explanation = (
"Tests numerical integration stability.\n"
"Solid: Energy evaluated at current state (r, v, L) using the *initial static potential*.\n"
"Dashed: Initial energy.\n"
"Note: This is not true energy conservation if the potential evolves over time."
)
# Place text box in bottom left, slightly offset from axes
ax.text(0.02, 0.02, explanation, transform=ax.transAxes, fontsize=8,
verticalalignment='bottom', bbox=dict(boxstyle='round,pad=0.3', fc='aliceblue', alpha=0.8))
# Adjust layout slightly to prevent xlabel overlap with the text box
plt.tight_layout(rect=[0, 0.05, 1, 1]) # Added bottom margin (rect=[left, bottom, right, top])
plt.savefig(output_file, dpi=200)
plt.close(fig) # Close the specific figure
log_plot_saved(output_file)
return output_file
[docs]
def plot_angular_momentum_time(input_file, output_file):
"""
Plot angular momentum vs. time for multiple particles, showing both current and initial values.
Parameters
----------
input_file : str
Path to the input data file containing angular momentum data.
output_file : str
Path to save the output plot.
Returns
-------
str
Path to the saved plot file
Notes
-----
The input file should contain time in the first column, followed by repeated groups
of values (current energy, initial energy, current angular momentum, initial angular momentum)
for each particle.
Loads full data.
"""
data = safe_load_and_filter_bin(
input_file, ncol_energy_and_angular_momentum_vs_time, dtype=np.float32)
if data is None:
return None
gc.collect()
mask = np.all(np.isfinite(data), axis=1)
data = data[mask]
time = data[:, 0]
ncols = data.shape[1]
nparticles = (ncols - 1)//4
plt.figure(figsize=(10, 6))
for p in range(nparticles):
L_col = 3 + 4*p
L_i_col = 4 + 4*p
# Use LaTeX in legend (using \ell)
plt.plot(time, data[:, L_col], linewidth=1.5, label=fr'$\ell(t)$ P{p+1}')
plt.plot(time, data[:, L_i_col], '--', linewidth=1.5, label=fr'$\ell_i$ P{p+1}')
# Use LaTeX for axis labels
plt.xlabel(r'$t$ (Myr)', fontsize=12)
plt.ylabel(r'$\ell$ (kpc$\cdot$km/s)', fontsize=12)
plt.title(r'Angular Momentum Conservation $\ell(t)$ vs $\ell_i$', fontsize=14)
plt.legend(fontsize=10)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig(output_file, dpi=200)
plt.close()
log_plot_saved(output_file)
return output_file
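# Hedged helper illustrating the column layout assumed by the two plots above:
# column 0 is time, and particle p (0-based) occupies columns 1+4p..4+4p as
# (E, E_initial, L, L_initial). Illustrative only; not used by the pipeline.
def _example_energy_file_columns(p):
    """Return (E_col, E_i_col, L_col, L_i_col) for particle index p (illustrative)."""
    return 1 + 4 * p, 2 + 4 * p, 3 + 4 * p, 4 + 4 * p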
[docs]
def plot_lowestL_trajectories_3panel(input_file="lowest_l_trajectories.dat", output_file="lowestl_3panel.png"):
"""
Create a 3-panel plot showing trajectories of particles with lowest angular momentum.
Parameters
----------
input_file : str, optional
Path to the input data file containing trajectory data.
Default is "lowest_l_trajectories.dat".
output_file : str, optional
Path to save the output plot. Default is "lowestl_3panel.png".
Returns
-------
None
Notes
-----
The input file should have columns in the format:
time r1 E1 L1 r2 E2 L2 ... rN EN LN
The function creates a 3-panel plot showing:
1) Radii vs. time for each particle
2) Energies vs. time for each particle
3) Percentage deviation from average energy vs. time for each particle
"""
# Load the data
data = safe_load_and_filter_bin(
input_file, ncol_lowest_l_trajectories, dtype=np.float32)
if data is None:
return
gc.collect()
# Filter out rows with NaN/Inf
mask = np.all(np.isfinite(data), axis=1)
data = data[mask]
# time = first column
time = data[:, 0]
ncols = data.shape[1]
# each lowest-L particle contributes 3 columns: (r, E, L)
# so number of lowest-L particles = (ncols - 1)//3
nlowest = (ncols - 1) // 3
# Prepare the figure with 3 side-by-side subplots
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
ax_r = axes[0]
ax_e = axes[1]
ax_dev = axes[2]
# 1) Radii vs. time
for p in range(nlowest):
r_col = 1 + 3*p # r is the first of each triplet
ax_r.plot(time, data[:, r_col], label=f"Particle {p+1}")
ax_r.set_xlabel(r"$t$ (Myr)", fontsize=12)
ax_r.set_ylabel(r"$r$ (kpc)", fontsize=12)
ax_r.set_title(r"Radius vs. Time (Lowest $\ell$)", fontsize=14)
ax_r.legend(fontsize=10)
ax_r.grid(True)
# 2) Energies vs. time
# second of each triplet is the energy
for p in range(nlowest):
e_col = 2 + 3*p
ax_e.plot(time, data[:, e_col], label=fr"$\mathcal{{E}}$ P{p+1}")
ax_e.set_xlabel(r"$t$ (Myr)", fontsize=12)
ax_e.set_ylabel(r"$\mathcal{E}$ (km$^2$/s$^2$)", fontsize=12)
ax_e.set_title(r"Energy vs. Time (Lowest $\ell$)", fontsize=14)
ax_e.legend(fontsize=10)
ax_e.grid(True)
# 3) % deviation from average energy
# We'll build an array (nsteps x nlowest) of energies
nsteps = data.shape[0]
energies = np.zeros((nsteps, nlowest), dtype=np.float32)
for p in range(nlowest):
e_col = 2 + 3*p
energies[:, p] = data[:, e_col]
# compute each particle's time-averaged energy
for p in range(nlowest):
# average of that particle's energy over time
single_mean = np.mean(energies[:, p])
# Guard against division by zero if the mean energy is (pathologically) zero
denom = single_mean if abs(single_mean) > 1e-30 else 1e-30
pct_dev = (energies[:, p] - single_mean) / denom * 100.0
ax_dev.plot(time, pct_dev, label=f"P{p+1}")
ax_dev.set_xlabel(r"$t$ (Myr)", fontsize=12)
ax_dev.set_ylabel(r"$\Delta\mathcal{E}/\langle\mathcal{E}\rangle$ (%)", fontsize=12)
ax_dev.set_title(r"Energy Deviation (Lowest $\ell$)", fontsize=14)
ax_dev.legend(fontsize=10)
ax_dev.grid(True)
# Save & close figure
plt.tight_layout()
plt.savefig(output_file, dpi=150)
plt.close()
gc.collect()
[docs]
def plot_debug_energy_compare(input_file, output_file="results/energy_compare.png"):
"""
Create multi-panel diagnostic plot comparing energy calculation methods.
Parameters
----------
input_file : str
Path to the input data file containing energy comparison data.
output_file : str, optional
Path to save the output plot. Default is "results/energy_compare.png".
Returns
-------
None
Notes
-----
Input file should have 8 columns:
- 0: snapIdx - Snapshot index
- 1: time(Myr) - Time in megayears
- 2: radius(kpc) - Radius in kiloparsecs
- 3: approxE - Approximated energy
- 4: dynE - Dynamically calculated energy
- 5: (dynE - approxE) - Difference between calculation methods
- 6: KE - Kinetic Energy
- 7: PE - Potential Energy
Creates a 5-panel figure in a 2x3 layout:
- Top row (3 panels):
[0,0]: Radius vs. time
[0,1]: Dynamic energy vs. time
[0,2]: Energy deviation from mean (%) vs. time
- Bottom row (2 panels + 1 blank):
[1,0]: Kinetic energy vs. time
[1,1]: Potential energy vs. time
[1,2]: (empty placeholder)
"""
data = safe_load_and_filter_bin(
input_file, ncol_debug_energy_compare, dtype=np.float32)
if data is None:
print(
truncate_and_pad_string(f"WARNING: {input_file} not found or no valid data. Skipping energy-compare plot."))
return
# Filter out any invalid (NaN/Inf) rows
mask = np.all(np.isfinite(data), axis=1)
data = data[mask]
if data.size == 0:
print(
truncate_and_pad_string(f"Warning: After filtering, {input_file} had no valid rows. Skipping."))
return
# Extract data columns from the energy comparison file
snap_idx = data[:, 0] # Snapshot index
time_myr = data[:, 1] # Time in megayears
radius_kpc = data[:, 2] # Radius in kiloparsecs
approx_e = data[:, 3] # Approximated energy
dyn_e = data[:, 4] # Dynamically calculated energy
diff_e = data[:, 5] # Difference between methods (dynE - approxE)
ke_vals = data[:, 6] # Kinetic energy
pe_vals = data[:, 7] # Potential energy
# Create a 2-row by 3-column figure
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(18, 10))
# Top row subplots
ax_r = axes[0, 0] # radius vs time
ax_e = axes[0, 1] # dynE vs time
ax_diff = axes[0, 2] # energy deviation from mean
# Bottom row subplots
ax_ke = axes[1, 0] # KE vs time
ax_pe = axes[1, 1] # PE vs time
# The last panel [1,2] we turn off or leave blank
axes[1, 2].axis('off')
# 1) Radius vs Time
ax_r.plot(time_myr, radius_kpc, 'b-', label=r"$r$")
ax_r.set_xlabel(r"$t$ (Myr)", fontsize=12)
ax_r.set_ylabel(r"$r$ (kpc)", fontsize=12)
ax_r.set_title(r"Radius vs. Time", fontsize=14)
ax_r.grid(True)
ax_r.legend(fontsize=10)
# 2) dynE vs. Time (Using \mathcal{E}) - negated
# Plot the negated energy values first to get auto-scaling
neg_dyn_e = -dyn_e
ax_e.plot(time_myr, neg_dyn_e, 'r-', label=r"$-\mathcal{E}_{\rm dyn}$")
ax_e.set_xlabel(r"$t$ (Myr)", fontsize=12)
ax_e.set_ylabel(r"$-\mathcal{E}$ (km$^2$/s$^2$)", fontsize=12)
ax_e.set_title(r"Energy vs. Time ($-\mathcal{E}_{\rm dyn}$)", fontsize=14)
# Get the auto-scaled range after plotting
epsilon_range_min, epsilon_range_max = ax_e.get_ylim()
# Calculate appropriate variation scale based on potential energy variation
psi_range = np.abs(np.max(pe_vals) - np.min(pe_vals))
delta = psi_range / 3.0 # Characteristic energy variation scale
neg_mean_e = np.mean(neg_dyn_e) # Mean of negated energy
# Calculate new limits that preserve both the auto-scaled range and our calculated range
new_min = min(epsilon_range_min, neg_mean_e - 0.5*delta)
new_max = max(epsilon_range_max, neg_mean_e + 0.5*delta)
# Set the expanded y-axis limits
ax_e.set_ylim(new_min, new_max)
ax_e.grid(True)
ax_e.legend(fontsize=10)
# 3) Energy deviation from mean (Using \mathcal{E})
# Calculate average deviation from time average
mean_dyn = np.mean(dyn_e)
# Prevent division by zero by using a minimal denominator value
eps_d = mean_dyn if abs(mean_dyn) > 1e-30 else 1e-30
dev_dyn_pct = 100.0 * (dyn_e - mean_dyn) / eps_d
# Plot shows the deviation percentage without negation
# as this represents relative change which is independent of sign
ax_diff.plot(time_myr, dev_dyn_pct, 'r-',
label=r"$(\mathcal{E}_{\rm dyn} - \langle\mathcal{E}_{\rm dyn}\rangle)/\langle\mathcal{E}_{\rm dyn}\rangle$ (%)")
ax_diff.set_xlabel(r"$t$ (Myr)", fontsize=12)
ax_diff.set_ylabel(r"$\Delta\mathcal{E}/\langle\mathcal{E}\rangle$ (%)", fontsize=12)
ax_diff.set_title(r"Energy Deviation from Mean", fontsize=14)
ax_diff.grid(True)
ax_diff.legend(fontsize=10)
# 4) Kinetic Energy vs. Time
ax_ke.plot(time_myr, ke_vals, 'c-', label=r"$K$")
ax_ke.set_xlabel(r"$t$ (Myr)", fontsize=12)
ax_ke.set_ylabel(r"$K$ (km$^2$/s$^2$)", fontsize=12)
ax_ke.set_title(r"Kinetic Energy vs. Time", fontsize=14)
ax_ke.grid(True)
ax_ke.legend(fontsize=10)
# 5) Potential Energy vs. Time (negated)
ax_pe.plot(time_myr, -pe_vals, 'y-', label=r"$-\Psi$")
ax_pe.set_xlabel(r"$t$ (Myr)", fontsize=12)
ax_pe.set_ylabel(r"$-\Psi$ (km$^2$/s$^2$)", fontsize=12)
ax_pe.set_title(r"Potential Energy vs. Time", fontsize=14)
ax_pe.grid(True)
ax_pe.legend(fontsize=10)
plt.tight_layout()
plt.savefig(output_file, dpi=150)
plt.close()
gc.collect()
# Log to file only, console output handled by progress tracker
logger.info(f"Plot saved: {output_file}")
# The calling function should call update_combined_progress; this function
# does not update the progress bar itself
[docs]
class Configuration:
"""
Configuration class that encapsulates all parameters and settings for the nsphere_plot.py script.
This reduces the reliance on global variables and improves code maintainability and testability.
"""
[docs]
def __init__(self, args=None):
"""
Initialize the configuration with command line arguments.
Parameters
----------
args : argparse.Namespace, optional
The parsed command line arguments.
If None, arguments will be parsed from sys.argv.
"""
# Command line arguments
self.args = args if args else self._parse_arguments()
# Set parameters from arguments, with validation
self._setup_parameters()
def _parse_arguments(self):
"""
Parse command line arguments for configuring the visualization options.
Returns
-------
argparse.Namespace
The parsed command line arguments with visualization settings.
Notes
-----
This method creates an ArgumentParser with options for controlling which
visualizations to generate and sets appropriate defaults based on lastparams.dat.
"""
# Note: showing_help flag is already set at the beginning of the script
# Parse command line arguments
parser = argparse.ArgumentParser(
description='Generate plots and animations from simulation data.',
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
# Get default suffix from lastparams.dat
default_suffix = read_lastparams(return_suffix=True)
parser.add_argument('--suffix', type=str, default=default_suffix,
help='Suffix for data files (e.g., _adp.leap.adp.levi_30000_1001_5)')
parser.add_argument('--start', type=int, default=0,
help='Starting snapshot number')
parser.add_argument('--end', type=int, default=0,
help='Ending snapshot number (0 means use all available)')
parser.add_argument('--step', type=int, default=1,
help='Step size between snapshots')
parser.add_argument('--fps', type=int, default=10,
help='Frames per second in the output animations')
# Flags for only generating specific visualization groups
parser.add_argument('--phase-space', action='store_true',
help='When used, ONLY generate phase space plots and animations')
parser.add_argument('--phase-comparison', action='store_true',
help='When used, ONLY generate initial vs. final phase space comparison')
parser.add_argument('--profile-plots', action='store_true',
help='When used, ONLY generate profile plots (density, mass, psi, etc.)')
parser.add_argument('--trajectory-plots', action='store_true',
help='When used, ONLY generate trajectory and diagnostic plots')
parser.add_argument('--2d-histograms', action='store_true',
help='When used, ONLY generate 2D histograms')
parser.add_argument('--convergence-tests', action='store_true',
help='When used, ONLY generate convergence test plots')
parser.add_argument('--animations', action='store_true',
help='When used, ONLY generate animations')
parser.add_argument('--energy-plots', action='store_true',
help='When used, ONLY generate energy plots')
# Flag for generating 1D variable distributions
parser.add_argument('--distributions', action='store_true',
help='When used, ONLY generate 1D variable distributions (Initial vs Final)')
# Flags for skipping specific visualization groups
parser.add_argument('--no-phase-space', action='store_true',
help='Skip phase space plots and animations')
parser.add_argument('--no-phase-comparison', action='store_true',
help='Skip initial vs. final phase space comparison')
parser.add_argument('--no-profile-plots', action='store_true',
help='Skip profile plots (density, mass, psi, etc.)')
parser.add_argument('--no-trajectory-plots', action='store_true',
help='Skip trajectory and diagnostic plots')
parser.add_argument('--no-histograms', action='store_true',
help='Skip 2D histograms')
parser.add_argument('--no-convergence-tests', action='store_true',
help='Skip convergence test plots')
parser.add_argument('--no-animations', action='store_true',
help='Skip animations')
parser.add_argument('--no-energy-plots', action='store_true',
help='Skip energy plots')
parser.add_argument('--no-distributions', action='store_true',
help='Skip 1D variable distributions (Initial vs Final)')
parser.add_argument('--log', action='store_true',
help='Enable detailed logging to file (default: only errors and warnings)')
parser.add_argument('--paced', action='store_true',
help='Paced mode: enable delays between sections (default: fast mode with no delays)')
return parser.parse_args()
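# Hedged examples of command lines these options support (the suffix shown is
# the one from the --suffix help text; substitute your own run):
#     python nsphere_plot.py --suffix _adp.leap.adp.levi_30000_1001_5 --fps 20
#     python nsphere_plot.py --profile-plots --log
#     python nsphere_plot.py --no-animations --no-histograms --paced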
def _setup_parameters(self):
"""
Configure parameters based on command line arguments and validate settings.
This method initializes the following configuration properties:
- suffix: File suffix for identifying data files
- fps: Frames per second for animations
- duration: Frame duration in milliseconds
- start_snap, end_snap, step_snap: Snapshot filtering parameters
It also handles global flags like enable_logging and paced_mode, and extracts
additional parameters from the suffix (npts, Ntimes, tfinal_factor, file_tag).
"""
global enable_logging, paced_mode, section_delay, progress_delay
# Set the suffix for data files
self.suffix = self.args.suffix
# Calculate frame duration from fps
self.fps = self.args.fps
self.duration = 1000.0 / self.fps
# Set the start, end, and step parameters for filtering Rank files
self.start_snap = self.args.start
self.end_snap = self.args.end
self.step_snap = self.args.step
# Handle logging flags (--log)
if self.args.log:
enable_logging = True
# Handle paced mode flag (--paced)
if self.args.paced:
paced_mode = True
section_delay = 5.0 # Delay between different sections in seconds when paced mode is enabled
progress_delay = 2.0 # Delay between progress bars in seconds when paced mode is enabled
# Parameter display is centralized in the main function banner
# Extract parameters from suffix for compatibility with existing code
self._extract_params_from_suffix()
def _extract_params_from_suffix(self):
"""
Parse the suffix string to extract simulation parameters.
This method parses the suffix pattern "_[file_tag]_npts_Ntimes_tfinal_factor"
to extract the following parameters:
- npts: Number of particles in the simulation
- Ntimes: Number of time steps in the simulation
- tfinal_factor: Final time factor of the simulation
- file_tag: Optional tag identifying the simulation run
If parsing fails, it falls back to reading from lastparams.dat.
Notes
-----
These parameters are used throughout the code for file paths,
plot labels, and determining output filenames.
"""
parts = self.suffix.strip('_').split('_')
# Default values
self.npts = 30000
self.Ntimes = 1001
self.tfinal_factor = 5
self.file_tag = ""
if len(parts) >= 3:
try:
# The last three parts should always be npts, Ntimes, tfinal_factor
self.npts = int(parts[-3])
self.Ntimes = int(parts[-2])
self.tfinal_factor = int(parts[-1])
# The file_tag is everything before the last three parts
if len(parts) > 3:
# Join any remaining parts as the file_tag
self.file_tag = '_'.join(parts[:-3])
except (ValueError, IndexError):
# Fallback if parsing fails, read from lastparams.dat
self.npts, self.Ntimes, self.tfinal_factor, self.file_tag = read_lastparams(return_suffix=False)
else:
# Fallback to reading from lastparams.dat
self.npts, self.Ntimes, self.tfinal_factor, self.file_tag = read_lastparams(return_suffix=False)
# If command-line suffix was provided, keep it as-is; otherwise use the reconstructed one
if not self.args.suffix:
# Rebuild suffix to ensure consistency
if self.file_tag:
self.suffix = f"_{self.file_tag}_{self.npts}_{self.Ntimes}_{self.tfinal_factor}"
else:
self.suffix = f"_{self.npts}_{self.Ntimes}_{self.tfinal_factor}"
[docs]
def setup_file_paths(self):
"""
Generate a dictionary mapping file types to their full file paths.
Returns
-------
dict
Dictionary where keys are file type identifiers (e.g., "particles",
"density_profile") and values are the corresponding file paths with
the configured suffix.
Notes
-----
This method constructs paths for all possible data files that might be
used in the visualization process. The actual existence of these files
is checked later when they are accessed.
"""
return {
"particles": f"data/particles{self.suffix}.dat",
"particlesfinal": f"data/particlesfinal{self.suffix}.dat",
"integrand": f"data/integrand{self.suffix}.dat",
"density_profile": f"data/density_profile{self.suffix}.dat",
"massprofile": f"data/massprofile{self.suffix}.dat",
"psiprofile": f"data/Psiprofile{self.suffix}.dat",
"dpsi_dr": f"data/dpsi_dr{self.suffix}.dat",
"drho_dpsi": f"data/drho_dpsi{self.suffix}.dat",
"f_of_E": f"data/f_of_E{self.suffix}.dat",
"df_fixed_radius": f"data/df_fixed_radius{self.suffix}.dat",
"combined_histogram": f"data/combined_histogram{self.suffix}.dat",
"trajectories": f"data/trajectories{self.suffix}.dat",
"single_trajectory": f"data/single_trajectory{self.suffix}.dat",
"energy_and_angular_momentum": f"data/energy_and_angular_momentum_vs_time{self.suffix}.dat",
"hist_init": f"data/2d_hist_initial{self.suffix}.dat",
"hist_final": f"data/2d_hist_final{self.suffix}.dat",
"lowest_l_trajectories": f"data/lowest_l_trajectories{self.suffix}.dat",
"debug_energy_compare": f"data/debug_energy_compare{self.suffix}.dat"
}
[docs]
def only_specific_visualizations(self):
"""
Determine if the user requested specific visualization types only.
Returns
-------
bool
True if any "only" flag is specified (e.g., --phase-space, --profile-plots),
False if running in normal mode where all visualizations are generated.
Notes
-----
When "only" flags are active, the script will generate only those specific
visualization types and skip all others, regardless of --no-* flags.
"""
return (self.args.phase_space or self.args.phase_comparison or
self.args.profile_plots or self.args.trajectory_plots or
getattr(self.args, '2d_histograms', False) or self.args.convergence_tests or
self.args.animations or self.args.energy_plots or
self.args.distributions)
[docs]
def need_to_process_rank_files(self):
"""
Determine if the visualization plan requires processing rank data files.
Returns
-------
bool
True if rank files need to be processed for animations or energy plots,
False if they can be skipped based on user-selected visualization options.
Notes
-----
Rank files are needed for animations and energy plots. This method checks
if either of these visualization types are specifically requested (via "only" flags)
or if they haven't been explicitly excluded (via "no-" flags) in normal mode.
"""
only_flags_active = self.only_specific_visualizations()
return (self.args.animations or self.args.energy_plots or
(not only_flags_active and (not self.args.no_animations or not self.args.no_energy_plots)))
[docs]
def parse_arguments():
"""
Parse command line arguments for the nsphere_plot.py script.
Returns
-------
argparse.Namespace
The parsed arguments.
"""
config = Configuration()
return config.args
[docs]
def setup_global_parameters(args):
"""
Setup global parameters based on command line arguments.
Parameters
----------
args : argparse.Namespace
The parsed command line arguments.
Returns
-------
tuple
The calculated parameters (suffix, start_snap, end_snap, step_snap, duration).
"""
global suffix, start_snap, end_snap, step_snap, duration
config = Configuration(args)
suffix = config.suffix
start_snap = config.start_snap
end_snap = config.end_snap
step_snap = config.step_snap
duration = config.duration
return suffix, start_snap, end_snap, step_snap, duration
[docs]
def setup_file_paths(suffix):
"""
Setup file paths for all data files using the given suffix.
Parameters
----------
suffix : str
The suffix to use for all file names.
Returns
-------
dict
Dictionary containing all file paths.
"""
temp_config = Configuration()
temp_config.suffix = suffix
return temp_config.setup_file_paths()
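# Hedged example: with suffix "_30000_1001_5" the returned dictionary maps,
# e.g., "density_profile" to "data/density_profile_30000_1001_5.dat".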
[docs]
def process_profile_plots(file_paths):
"""
Process various profile data files and generate plots.
Parameters
----------
file_paths : dict
Dictionary of file paths for various data files.
"""
global suffix
# Create the results directory if it doesn't exist
os.makedirs("results", exist_ok=True)
# Calculate total number of data files to load (for progress tracking)
# Potential data files to load: density, mass, psi, dpsi, drho, f_of_E, df_fixed, combined
total_data_files = 8
# Start progress tracking for data loading
start_combined_progress("profile_data_loading", total_data_files)
# First, prepare all the data that needs to be plotted
prepared_plots = []
# Plot density profile (now using rho, not 4*pi*r^2*rho)
logger.info(f"Loading density profile data from {file_paths['density_profile']}")
update_combined_progress("profile_data_loading", "loading_density_profile")
dens_data = safe_load_and_filter_bin(
file_paths["density_profile"], ncol_density_profile, dtype=np.float32)
if dens_data is not None:
gc.collect()
mask = np.all(np.isfinite(dens_data), axis=1)
dens_data = dens_data[mask]
output_file = f"results/density_profile{suffix}.png"
prepared_plots.append(("density", dens_data[:, 0], dens_data[:, 1], output_file))
# Plot mass profile
logger.info(f"Loading mass profile data from {file_paths['massprofile']}")
update_combined_progress("profile_data_loading", "loading_mass_profile")
mass_data = safe_load_and_filter_bin(
file_paths["massprofile"], ncol_mass_profile, dtype=np.float32)
if mass_data is not None:
gc.collect()
mask = np.all(np.isfinite(mass_data), axis=1)
mass_data = mass_data[mask]
output_file = f"results/mass_enclosed{suffix}.png"
prepared_plots.append(("mass", mass_data[:, 0], mass_data[:, 1], output_file))
# Plot psi profile
logger.info(f"Loading psi profile data from {file_paths['psiprofile']}")
update_combined_progress("profile_data_loading", "loading_psi_profile")
psi_data = safe_load_and_filter_bin(
file_paths["psiprofile"], ncol_psi_profile, dtype=np.float32)
if psi_data is not None:
gc.collect()
mask = np.all(np.isfinite(psi_data), axis=1)
psi_data = psi_data[mask]
VEL_CONV_SQ = (1.02271e-3*1.02271e-3)
psi_data_scaled = psi_data[:, 1]*VEL_CONV_SQ # Scale to km^2/s^2
output_file = f"results/psi_profile{suffix}.png"
prepared_plots.append(("psi", psi_data[:, 0], psi_data_scaled, output_file))
# Plot dpsi/dr profile
logger.info(f"Loading dpsi/dr profile data from {file_paths['dpsi_dr']}")
update_combined_progress("profile_data_loading", "loading_dpsi_dr_profile")
dpsi_data = safe_load_and_filter_bin(
file_paths["dpsi_dr"], ncol_dpsi_dr, dtype=np.float32)
if dpsi_data is not None:
gc.collect()
mask = np.all(np.isfinite(dpsi_data), axis=1)
dpsi_data = dpsi_data[mask]
output_file = f"results/dpsi_dr{suffix}.png"
prepared_plots.append(("dpsi", dpsi_data[:, 0], dpsi_data[:, 1], output_file))
# Plot drho/dpsi profile
logger.info(f"Loading drho/dpsi profile data from {file_paths['drho_dpsi']}")
update_combined_progress("profile_data_loading", "loading_drho_dpsi_profile")
drho_data = safe_load_and_filter_bin(
file_paths["drho_dpsi"], ncol_drho_dpsi, dtype=np.float32)
if drho_data is not None:
gc.collect()
mask = np.all(np.isfinite(drho_data), axis=1)
drho_data = drho_data[mask]
output_file = f"results/drho_dpsi{suffix}.png"
prepared_plots.append(("drho", drho_data[:, 0], drho_data[:, 1], output_file))
# Plot f(E) profile
logger.info(f"Loading f(E) profile data from {file_paths['f_of_E']}")
update_combined_progress("profile_data_loading", "loading_f_of_E_profile")
fE_data = safe_load_and_filter_bin(
file_paths["f_of_E"], ncol_f_of_E, dtype=np.float32)
if fE_data is not None:
gc.collect()
mask = np.all(np.isfinite(fE_data), axis=1)
fE_data = fE_data[mask]
output_file = f"results/f_of_E{suffix}.png"
prepared_plots.append(("f_of_E", fE_data[:, 0], fE_data[:, 1], output_file))
# Plot df at fixed radius
logger.info(f"Loading df at fixed radius data from {file_paths['df_fixed_radius']}")
update_combined_progress("profile_data_loading", "loading_df_fixed_radius_profile")
# Load directly as float32 array - each row has two 4-byte floats
try:
# Load binary data
with open(file_paths["df_fixed_radius"], 'rb') as f:
raw_data = np.fromfile(f, dtype=np.float32)
# Reshape into rows of 2 columns
if len(raw_data) >= 2: # Make sure we have at least one complete row
rows = len(raw_data) // 2
df_fixed_data = raw_data.reshape(rows, 2)
logger.info(f"Successfully loaded df_fixed_radius data: {rows} rows of velocity/distribution data")
else:
logger.warning(f"Not enough data in {file_paths['df_fixed_radius']}")
df_fixed_data = None
except Exception as e:
logger.error(f"Error loading df_fixed_radius data: {e}")
# Fall back to standard loading method
df_fixed_data = safe_load_and_filter_bin(
file_paths["df_fixed_radius"], ncol_df_fixed_radius, dtype=np.float32)
if df_fixed_data is not None:
gc.collect()
# Filter out any NaN or Inf values
mask = np.all(np.isfinite(df_fixed_data), axis=1)
df_fixed_data = df_fixed_data[mask]
# Log detailed information about the data
logger.info(f"df_fixed_radius data loaded successfully, shape: {df_fixed_data.shape}")
if df_fixed_data.size > 0:
logger.info(f"df_fixed_radius data range - v: [{np.min(df_fixed_data[:,0])}, {np.max(df_fixed_data[:,0])}], " +
f"f: [{np.min(df_fixed_data[:,1])}, {np.max(df_fixed_data[:,1])}]")
logger.info(f"First 5 rows: {df_fixed_data[:5]}")
output_file = f"results/df_fixed_radius{suffix}.png"
prepared_plots.append(("df_fixed", df_fixed_data[:, 0], df_fixed_data[:, 1], output_file))
# Combined histogram
combined_histogram_file = file_paths['combined_histogram']
logger.info(f"Loading combined histogram data from {combined_histogram_file}")
update_combined_progress("profile_data_loading", "loading_combined_histogram")
output_file = f"results/combined_radial_distribution{suffix}.png"
prepared_plots.append(("combined", combined_histogram_file, None, output_file))
# Now generate all the plots with progress tracking
total_plots = len(prepared_plots)
for i, (plot_type, x_data, y_data, output_file) in enumerate(prepared_plots, 1):
try:
if plot_type == "density":
plot_density(x_data, y_data, output_file)
elif plot_type == "mass":
plot_mass_enclosed(x_data, y_data, output_file)
elif plot_type == "psi":
plot_psi(x_data, y_data, output_file)
elif plot_type == "dpsi":
plot_dpsi_dr(x_data, y_data, output_file)
elif plot_type == "drho":
plot_drho_dpsi(x_data, y_data, output_file)
elif plot_type == "f_of_E":
plot_f_of_E(x_data, y_data, output_file)
elif plot_type == "df_fixed":
# Use a fixed radius of 200.0 kpc
plot_df_at_fixed_radius(x_data, y_data, 200.0, output_file)
elif plot_type == "combined":
plot_combined_histogram_from_file(x_data, output_file)
# Log to file and update console display with progress
log_plot_saved(output_file, current=i, total=total_plots)
except Exception as e:
logger.error(f"Error generating {plot_type} plot: {str(e)}")
print_status(f"Error generating {plot_type} plot: {str(e)}")
continue
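# Hedged usage sketch (the real driver in main() also handles progress setup
# and section ordering; shown only to illustrate the expected input):
#     paths = setup_file_paths(suffix)
#     process_profile_plots(paths)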
[docs]
def process_trajectory_energy_plots(file_paths, include_angular_momentum=True):
"""
Process energy plots that are part of the trajectory plots but should also
be included in the energy plots category.
Parameters
----------
file_paths : dict
Dictionary of file paths for various data files.
include_angular_momentum : bool, optional
Whether to include angular momentum plots. Default is True.
"""
global suffix
# Create the results directory if it doesn't exist
os.makedirs("results", exist_ok=True)
# Prepare plots to be generated
prepared_plots = []
# Energy vs time plot
energy_output_file = f"results/energy_vs_time{suffix}.png"
prepared_plots.append(("energy", file_paths["energy_and_angular_momentum"], energy_output_file))
# Angular momentum vs time (if requested)
if include_angular_momentum:
angular_output_file = f"results/angular_momentum_vs_time{suffix}.png"
prepared_plots.append(("angular", file_paths["energy_and_angular_momentum"], angular_output_file))
# Debug energy comparison (if available)
debug_file = file_paths["debug_energy_compare"]
if os.path.exists(debug_file):
# Define output filename for the energy comparison plot
debug_output_file = f"results/energy_compare{suffix}.png"
prepared_plots.append(("debug", debug_file, debug_output_file))
else:
logger.warning(f"Energy compare file {debug_file} not found; skipping energy compare plot.")
# Define a unique section key for this set of plots
section_key = "additional_energy_plots"
total_plots = len(prepared_plots)
# Start combined progress tracking for these plots
if total_plots > 0:
start_combined_progress(section_key, total_plots)
# Generate all plots with progress tracking
for plot_type, input_file, output_file in prepared_plots:
try:
if plot_type == "energy":
plot_energy_time(input_file, output_file)
elif plot_type == "angular":
plot_angular_momentum_time(input_file, output_file)
elif plot_type == "debug":
plot_debug_energy_compare(input_file, output_file=output_file)
# When called from trajectory_plots, update the combined progress tracker
if "trajectory_plots" in _combined_plot_trackers:
update_combined_progress("trajectory_plots", output_file)
else:
# Otherwise, use our section tracker
update_combined_progress(section_key, output_file)
except Exception as e:
logger.error(f"Error generating {plot_type} plot: {str(e)}")
print_status(f"Error generating {plot_type} plot: {str(e)}")
continue
else:
# If no plots to generate, just log and continue
logger.info("No additional energy plots to generate.")
[docs]
def process_trajectory_plots(file_paths):
"""
Process trajectory and diagnostic plots.
Parameters
----------
file_paths : dict
Dictionary of file paths for various data files.
"""
global suffix
# Create the results directory if it doesn't exist
os.makedirs("results", exist_ok=True)
# Calculate total number of data files to load (for progress tracking)
# Four different data files to potentially load: trajectories, single_trajectory, lowest_l_trajectories, energy_and_angular_momentum
total_data_files = 4
# Start progress tracking for data loading
start_combined_progress("trajectory_data_loading", total_data_files)
# Prepare plots to be generated
prepared_plots = []
# Trajectory plots - check and load
logger.info(f"Loading trajectory data from {file_paths['trajectories']}")
update_combined_progress("trajectory_data_loading", "loading_trajectories")
if os.path.exists(file_paths["trajectories"]):
trajectories_output = f"results/trajectories{suffix}.png"
prepared_plots.append(("trajectories", file_paths["trajectories"], trajectories_output))
else:
logger.warning(f"Trajectory file {file_paths['trajectories']} not found, skipping.")
# Single trajectory - check and load
logger.info(f"Loading single trajectory data from {file_paths['single_trajectory']}")
update_combined_progress("trajectory_data_loading", "loading_single_trajectory")
if os.path.exists(file_paths["single_trajectory"]):
single_trajectory_output = f"results/single_trajectory{suffix}.png"
prepared_plots.append(("single", file_paths["single_trajectory"], single_trajectory_output))
else:
logger.warning(f"Single trajectory file {file_paths['single_trajectory']} not found, skipping.")
# Lowest L 3-panel - check and load
logger.info(f"Loading lowest L trajectories data from {file_paths['lowest_l_trajectories']}")
update_combined_progress("trajectory_data_loading", "loading_lowest_l_trajectories")
if os.path.exists(file_paths["lowest_l_trajectories"]):
lowest_l_output = f"results/lowestL_3panel{suffix}.png"
prepared_plots.append(("lowest_l", file_paths["lowest_l_trajectories"], lowest_l_output))
else:
logger.warning(f"Lowest L trajectories file {file_paths['lowest_l_trajectories']} not found, skipping.")
# Energy and angular momentum data - check and load
logger.info(f"Loading energy and angular momentum data from {file_paths['energy_and_angular_momentum']}")
update_combined_progress("trajectory_data_loading", "loading_energy_angular_momentum")
# Generate trajectory plots using the combined progress tracker
for plot_type, input_file, output_file in prepared_plots:
try:
if plot_type == "trajectories":
plot_trajectories(input_file, output_file)
elif plot_type == "single":
plot_single_trajectory(input_file, output_file)
elif plot_type == "lowest_l":
plot_lowestL_trajectories_3panel(input_file, output_file)
# Update the combined progress tracker (shared with energy plots)
update_combined_progress("trajectory_plots", output_file)
except Exception as e:
logger.error(f"Error generating {plot_type} plot: {str(e)}")
print_status(f"Error generating {plot_type} plot: {str(e)}")
continue
# Also generate the energy plots from this category
process_trajectory_energy_plots(file_paths)
# Global variables to store snapshot data for animation rendering
mass_snapshots = []
density_snapshots = []
psi_snapshots = []
# Global variables to store max values for consistent scaling across frames
mass_max_value = 0.0
density_max_value = 0.0
psi_max_value = 0.0
[docs]
def calculate_global_max_values():
"""
Calculate the maximum values of mass, density, and potential across all snapshots.
This ensures consistent scaling in animations.
Returns
-------
tuple
(mass_max, density_max, psi_max) - Maximum values for each quantity
"""
global mass_snapshots, density_snapshots, psi_snapshots
global mass_max_value, density_max_value, psi_max_value
# Find max mass value
if mass_snapshots:
mass_max_value = max(np.max(mass_data) for _, _, mass_data in mass_snapshots)
# Add 5% margin
mass_max_value *= 1.05
# Find max density value (4*pi*r^2*rho)
if density_snapshots:
density_max_value = max(np.max(density_data) for _, _, density_data in density_snapshots)
# Add 5% margin
density_max_value *= 1.05
# Find max psi value
if psi_snapshots:
psi_max_value = max(np.max(psi_data) for _, _, psi_data in psi_snapshots)
# Add 5% margin
psi_max_value *= 1.05
logger.info(f"Global max values: Mass={mass_max_value:.3e}, Density={density_max_value:.3e}, Psi={psi_max_value:.3e}")
return (mass_max_value, density_max_value, psi_max_value)
# Frame rendering functions for animation
[docs]
def render_mass_frame(frame_data):
"""
Render a single frame of mass profile animation.
Parameters
----------
frame_data : tuple
Tuple containing (snapshot_data, tfinal_factor, total_snapshots, r_max, project_root_path) where
snapshot_data is the (snap, radius, mass) tuple from mass_snapshots.
Returns
-------
numpy.ndarray
Image data for the rendered frame.
"""
# Access only mass_max_value from globals, not mass_snapshots
global mass_max_value
# Unpack the passed data tuple
snapshot_data, tfinal_factor, total_snapshots, r_max, project_root_path = frame_data
# Unpack the actual data from the snapshot tuple
snap, radius, mass = snapshot_data
# Calculate time in dynamical times using the passed total_snapshots
if total_snapshots and total_snapshots > 1:
t_dyn_fraction = snap / (total_snapshots - 1) * tfinal_factor
else:
# Fallback if total_snapshots is 0 or 1
t_dyn_fraction = 0.0
fig, ax = plt.subplots(figsize=(10, 6))
ax.set_title("Mass Profile Evolution", fontsize=14, pad=20)
ax.set_xlabel(r"$r$ (kpc)", fontsize=12)
ax.set_ylabel(r"$M(r)$ (M$_\odot$)", fontsize=12)
# Use the pre-calculated r_max value that was passed in
ax.set_xlim(0, r_max)
# Use calculated maximum mass value
ax.set_ylim(0, mass_max_value if mass_max_value > 0 else 6.2e11)
ax.grid(True, which='both', linestyle='--')
ax.plot(radius, mass, lw=2)
# Add text in upper right corner showing time in dynamical times
# Regular black text on semi-transparent white background for mass profile
ax.text(0.98, 0.95, f"$t = {t_dyn_fraction:.2f}\\,t_{{\\rm dyn}}$",
transform=ax.transAxes, ha='right', va='top', fontsize=12,
bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=3))
buf = BytesIO()
plt.savefig(buf, format='png', dpi=100)
plt.close(fig)
buf.seek(0)
img = imageio.imread(buf)
buf.close()
return img
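# Hedged sketch of assembling one frame_data tuple for render_mass_frame().
# The driver that builds these from snapshot files lives elsewhere in this
# script, so the toy profile and the numbers below are purely illustrative.
def _example_mass_frame_data(snap=0, tfinal_factor=5, total_snapshots=101, r_max=250.0):
    """Build an illustrative frame_data tuple accepted by render_mass_frame()."""
    radius = np.linspace(0.0, r_max, 100)              # kpc
    mass = 6.0e11 * radius / max(r_max, 1e-30)         # toy monotonic M(r) in Msun
    snapshot_data = (snap, radius, mass)
    return (snapshot_data, tfinal_factor, total_snapshots, r_max, PROJECT_ROOT)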
[docs]
def render_density_frame(frame_data):
"""
Render a single frame of density profile animation.
Note: Assumes density_snapshots contains 4*pi*r^2*rho.
Parameters
----------
frame_data : tuple
Tuple containing (snapshot_data, tfinal_factor, total_snapshots, r_max, project_root_path) where
snapshot_data is the (snap, radius, density) tuple from density_snapshots.
Returns
-------
numpy.ndarray
Image data for the rendered frame.
"""
# Access only density_max_value from globals, not density_snapshots
global density_max_value
# Unpack the passed data tuple
snapshot_data, tfinal_factor, total_snapshots, r_max, project_root_path = frame_data
# Unpack the actual data from the snapshot tuple
snap, radius, density = snapshot_data # Assumed this is 4*pi*r^2*rho
# Calculate time in dynamical times using the passed total_snapshots
if total_snapshots and total_snapshots > 1:
t_dyn_fraction = snap / (total_snapshots - 1) * tfinal_factor
else:
# Fallback if total_snapshots is 0 or 1
t_dyn_fraction = 0.0
fig, ax = plt.subplots(figsize=(10, 6))
ax.set_title("Density Profile Evolution", fontsize=14, pad=20)
ax.set_xlabel(r"$r$ (kpc)", fontsize=12)
ax.set_ylabel(r"$4\pi r^2 \rho(r)$ (M$_\odot$/kpc)", fontsize=12) # Label matches data
# Use the pre-calculated r_max value that was passed in
ax.set_xlim(0, r_max)
# Use calculated maximum density value
ax.set_ylim(0, density_max_value if density_max_value > 0 else 1.2e10)
ax.grid(True, which='both', linestyle='--')
# Directly plot the 'density' variable, as it's assumed to be 4*pi*r^2*rho already
ax.plot(radius, density, lw=2)
# Add text in upper right corner showing time in dynamical times
# Regular black text on semi-transparent white background for density profile
ax.text(0.98, 0.95, f"$t = {t_dyn_fraction:.2f}\\,t_{{\\rm dyn}}$",
transform=ax.transAxes, ha='right', va='top', fontsize=12,
bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=3))
buf = BytesIO()
plt.savefig(buf, format='png', dpi=100)
plt.close(fig)
buf.seek(0)
img = imageio.imread(buf)
buf.close()
return img
[docs]
def render_psi_frame(frame_data):
"""
Render a single frame of psi profile animation.
Parameters
----------
frame_data : tuple
Tuple containing (snapshot_data, tfinal_factor, total_snapshots, r_max, project_root_path) where
snapshot_data is the (snap, radius, psi) tuple from psi_snapshots.
Returns
-------
numpy.ndarray
Image data for the rendered frame.
"""
# Access only psi_max_value from globals, not psi_snapshots
global psi_max_value
# Unpack the passed data tuple
snapshot_data, tfinal_factor, total_snapshots, r_max, project_root_path = frame_data
# Unpack the actual data from the snapshot tuple
snap, radius, psi = snapshot_data
# Calculate time in dynamical times using the passed total_snapshots
if total_snapshots and total_snapshots > 1:
t_dyn_fraction = snap / (total_snapshots - 1) * tfinal_factor
else:
# Fallback if total_snapshots is 0 or 1
t_dyn_fraction = 0.0
fig, ax = plt.subplots(figsize=(10, 6))
ax.set_title("Potential Profile Evolution", fontsize=14, pad=20)
ax.set_xlabel(r"$r$ (kpc)", fontsize=12)
ax.set_ylabel(r"$\Psi(r)$ (km$^2$/s$^2$)", fontsize=12)
# Use the pre-calculated r_max value that was passed in
ax.set_xlim(0, r_max)
# Use calculated maximum psi value
ax.set_ylim(0, psi_max_value if psi_max_value > 0 else 0.072)
ax.grid(True, which='both', linestyle='--')
ax.plot(radius, psi, lw=2)
# Add text in upper right corner showing time in dynamical times
# Regular black text on semi-transparent white background for potential profile
ax.text(0.98, 0.95, f"$t = {t_dyn_fraction:.2f}\\,t_{{\\rm dyn}}$",
transform=ax.transAxes, ha='right', va='top', fontsize=12,
bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=3))
buf = BytesIO()
plt.savefig(buf, format='png', dpi=100)
plt.close(fig)
buf.seek(0)
img = imageio.imread(buf)
buf.close()
return img
[docs]
def process_particle_data(data, n_cols):
"""
Extracts key quantities and calculates total velocity.
Parameters
----------
data : ndarray
Particle data array
n_cols : int
Number of columns in the data
Returns
-------
tuple
(radius, radial_velocity [km/s], angular_momentum, total_velocity [km/s]) for valid particles
"""
# These use global constants defined at the top of the file
if data is None or data.shape[0] == 0:
return None, None, None, None
if n_cols == ncol_particles_initial:
radii, radial_velocity, angular_momentum = data[:, 0], data[:, 1], data[:, 2]
elif n_cols == ncol_Rank_Mass_Rad_VRad_unsorted:
radii, radial_velocity, angular_momentum = data[:, 2], data[:, 3], data[:, 6]
else:
log_message(f"Error: Unsupported column count ({n_cols})", level="error")
return None, None, None, None
nonzero_mask = radii > 0
if not np.any(nonzero_mask):
log_message("Warning: No particles with positive radius found.", level="warning")
return None, None, None, None
radii_nz = radii[nonzero_mask]
radial_velocity_nz = radial_velocity[nonzero_mask]
angular_momentum_nz = angular_momentum[nonzero_mask]
if np.any(radii_nz <= 0):
log_message("Warning: Zero or negative radius found after mask?", level="warning")
return None, None, None, None
with np.errstate(divide='ignore', invalid='ignore'):
tangential_velocity_sim = angular_momentum_nz / radii_nz
total_velocity_sim = np.sqrt(tangential_velocity_sim**2 + radial_velocity_nz**2)
kmsec_to_kpcmyr = 1.02271e-3
total_velocity_kms = total_velocity_sim / kmsec_to_kpcmyr
valid_mask = np.isfinite(radii_nz) & np.isfinite(total_velocity_kms) & \
np.isfinite(radial_velocity_nz) & np.isfinite(angular_momentum_nz) & \
(total_velocity_kms > 0) & (radii_nz > 0)
if not np.any(valid_mask):
log_message("Warning: No valid particles after filtering infinities/NaNs.", level="warning")
return None, None, None, None
final_radii = radii_nz[valid_mask]
final_vr_kms = (radial_velocity_nz[valid_mask]) / kmsec_to_kpcmyr
final_L = angular_momentum_nz[valid_mask]
final_vtotal_kms = total_velocity_kms[valid_mask]
log_message(f"Processed data: {len(final_radii)} valid particles.", level="debug")
return final_radii, final_vr_kms, final_L, final_vtotal_kms
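# Illustrative sketch (not executed): the unit conversion applied above, worked through
# with hypothetical values. Simulation velocities are in kpc/Myr; dividing by
# kmsec_to_kpcmyr (1 km/s expressed in kpc/Myr) converts them to km/s:
#
#     r, v_r, L = 10.0, 0.05, 0.8                  # kpc, kpc/Myr, kpc^2/Myr (hypothetical)
#     v_tan = L / r                                # 0.08 kpc/Myr
#     v_tot_sim = math.sqrt(v_tan**2 + v_r**2)     # ~0.0943 kpc/Myr
#     v_tot_kms = v_tot_sim / 1.02271e-3           # ~92 km/s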
def plot_particles_histograms(suffix, progress_callback=None):
"""
Plots 2D histograms from the particles files.
• Reads the initial particles file: data/particles{suffix}.dat (expected 4 columns)
and plots a histogram with plt.hist2d.
• Reads the final particles file: data/particlesfinal{suffix}.dat (expected 4 columns);
computes a derived velocity value and then plots a 2D histogram.
Resulting images are saved to the "results" folder.
Parameters
----------
suffix : str
Suffix for input/output files.
progress_callback : callable, optional
Function to call after each plot is saved,
with the output file path as argument.
"""
# Constants for expected number of columns
ncol_particles_initial = 4
ncol_particles_final = 4
# Track plots we'll generate
output_files = []
# Start progress tracking for data loading (4 steps: check initial, load initial, check final, load final)
start_combined_progress("particles_histograms_data_loading", 4)
# ---------------------------------------
# Plot the initial particles file histogram
particles_file = f"data/particles{suffix}.dat"
# Update progress - checking initial file
update_combined_progress("particles_histograms_data_loading", "checking_initial_file")
if os.path.exists(particles_file):
# Use the appropriate column count for initial particles
logger.info(f"Loading initial particles data from {particles_file}")
# Update progress - loading initial data
update_combined_progress("particles_histograms_data_loading", "loading_initial_data")
data = safe_load_and_filter_bin(
particles_file, ncol_particles_initial, np.float32)
if data is None or data.shape[0] == 0:
logger.warning(f"No valid data in {particles_file}")
print_status(f"No valid data in {particles_file}")
else:
iradii = data[:, 0]
radialvelocities = data[:, 1]
ell = data[:, 2]
# Avoid division by zero by filtering out non-positive radii.
nonzero_mask = iradii > 0
iradii = iradii[nonzero_mask]
radialvelocities = radialvelocities[nonzero_mask]
ell = ell[nonzero_mask]
# Log data processing details
logger.info(f"Processing initial particles data, shape before filtering: {data.shape}")
logger.info(f"Non-zero radii: {np.sum(nonzero_mask)}/{len(nonzero_mask)}")
# Record original data size for debugging
global particles_original_count
particles_original_count = data.shape[0]
logger.info(f"Particles 2D initial histogram - Original data size: {particles_original_count}")
# Log original array size before filtering
logger.info(f"Original data array size (before filtering): {len(iradii)}")
# Compute the derived (total) velocity as given by:
# ivelocities = sqrt(ell^2 + (radialvelocities^2 * iradii^2)) / iradii / 1.02271e-3
ivelocities = np.sqrt(ell**2 + (radialvelocities**2)
* (iradii**2)) / iradii / 1.02271e-3
# Filter out extreme values for visualization
valid_mask = (ivelocities > 0) & (ivelocities < 1e8)
iradii = iradii[valid_mask]
ivelocities = ivelocities[valid_mask]
logger.info(f"Valid velocities after filtering: {np.sum(valid_mask)}/{len(valid_mask)}")
if len(iradii) == 0:
logger.warning("All initial radii/velocities invalid or empty.")
print_status("All initial radii/velocities invalid or empty.")
else:
plt.figure(figsize=(8, 6))
plt.hist2d(iradii, ivelocities, bins=250, range=[
[0, 250], [0, 320]], cmap='viridis')
plt.colorbar(label='Counts')
plt.xlabel(r'$r$ (kpc)', fontsize=12)
plt.ylabel(r'$v$ (km/s)', fontsize=12)
plt.title(r'Initial Phase Space Distribution', fontsize=14)
plt.xlim(0, 250)
plt.ylim(0, 320)
output_file = f"results/2d_histogram_initial{suffix}.png"
# Save original data size before any filtering
original_size = data.shape[0]
# Calculate and store histogram statistics
hist_counts, _, _ = np.histogram2d(iradii, ivelocities, bins=250, range=[[0, 250], [0, 320]])
initial_max = np.max(hist_counts)
initial_nonzero = np.count_nonzero(hist_counts)
initial_mean = np.sum(hist_counts)/initial_nonzero if initial_nonzero > 0 else 0
initial_total_particles = len(iradii)
initial_histogram_sum = np.sum(hist_counts)
logger.info(f"Original particle count: {original_size}")
logger.info(f"Filtered particle count: {initial_total_particles}")
logger.info(f"Histogram sum: {initial_histogram_sum}")
# Log detailed statistics
logger.info(f"Initial histogram statistics: Max={initial_max}, Mean={initial_mean:.2f}, Non-Zero Bins={initial_nonzero}, Total Particles={initial_total_particles}")
plt.savefig(output_file, dpi=150)
plt.close()
output_files.append(output_file)
else:
logger.warning(f"File {particles_file} not found.")
print_status(f"File {particles_file} not found.")
# ---------------------------------------
# Plot final particles histogram
particlesfinal_file = f"data/particlesfinal{suffix}.dat"
# Update progress - checking final file
update_combined_progress("particles_histograms_data_loading", "checking_final_file")
if os.path.exists(particlesfinal_file):
logger.info(f"Loading final particles data from {particlesfinal_file}")
# Update progress - loading final data
update_combined_progress("particles_histograms_data_loading", "loading_final_data")
data = safe_load_and_filter_bin(
particlesfinal_file, ncol_particles_final, np.float32)
if data is None or data.shape[0] == 0:
logger.warning(f"No valid data in {particlesfinal_file}")
print_status(f"No valid data in {particlesfinal_file}")
else:
fradii = data[:, 0]
radialvelocities = data[:, 1]
ell = data[:, 2]
# Log data processing details
logger.info(f"Processing final particles data, shape before filtering: {data.shape}")
# Record original data size for debugging
global particles_final_original_count
particles_final_original_count = data.shape[0]
logger.info(f"Particles 2D final histogram - Original data size: {particles_final_original_count}")
# Avoid division by zero by filtering out non-positive radii.
nonzero_mask = fradii > 0
fradii = fradii[nonzero_mask]
radialvelocities = radialvelocities[nonzero_mask]
ell = ell[nonzero_mask]
logger.info(f"Non-zero radii: {np.sum(nonzero_mask)}/{len(nonzero_mask)}")
# Log original array size before filtering
logger.info(f"Original data array size (before filtering): {len(fradii)}")
# Compute the derived velocity as given by:
# fvelocities = sqrt(ell^2 + (radialvelocities^2 * fradii^2)) / fradii / 1.02271e-3
fvelocities = np.sqrt(
ell**2 + (radialvelocities**2) * (fradii**2)) / fradii / 1.02271e-3
# Filter out extreme values for visualization
valid_mask = (fvelocities > 0) & (fvelocities < 1e8)
fradii = fradii[valid_mask]
fvelocities = fvelocities[valid_mask]
logger.info(f"Valid velocities after filtering: {np.sum(valid_mask)}/{len(valid_mask)}")
if len(fradii) == 0:
logger.warning("All final radii/velocities invalid or empty.")
print_status("All final radii/velocities invalid or empty.")
else:
plt.figure(figsize=(8, 6))
plt.hist2d(fradii, fvelocities, bins=250, range=[
[0, 250], [0, 320]], cmap='viridis')
plt.colorbar(label='Counts')
plt.xlabel(r'$r$ (kpc)', fontsize=12)
plt.ylabel(r'$v$ (km/s)', fontsize=12)
plt.title(r'Final Phase Space Distribution', fontsize=14)
plt.xlim(0, 250)
plt.ylim(0, 320)
output_file = f"results/2d_histogram_final{suffix}.png"
# Save original data size before any filtering
original_size = data.shape[0]
# Calculate and store histogram statistics
hist_counts, _, _ = np.histogram2d(fradii, fvelocities, bins=250, range=[[0, 250], [0, 320]])
final_max = np.max(hist_counts)
final_nonzero = np.count_nonzero(hist_counts)
final_mean = np.sum(hist_counts)/final_nonzero if final_nonzero > 0 else 0
final_total_particles = len(fradii)
final_histogram_sum = np.sum(hist_counts)
logger.info(f"Original particle count: {original_size}")
logger.info(f"Filtered particle count: {final_total_particles}")
logger.info(f"Histogram sum: {final_histogram_sum}")
# Log detailed statistics
logger.info(f"Final histogram statistics: Max={final_max}, Mean={final_mean:.2f}, Non-Zero Bins={final_nonzero}, Total Particles={final_total_particles}")
plt.savefig(output_file, dpi=150)
plt.close()
output_files.append(output_file)
else:
logger.warning(f"File {particlesfinal_file} not found.")
print_status(f"File {particlesfinal_file} not found.")
# Print histogram statistics to console for valid data
if len(output_files) > 0:
print_status("Particles histogram statistics:")
# Display initial histogram statistics
if 'initial_max' in locals() and 'initial_nonzero' in locals() and 'initial_mean' in locals():
print_status(f"Initial Data: Max. Value = {initial_max}, Non-Zero Bins = {initial_nonzero}, Mean Value (Non-Zero) = {initial_mean:.2f}")
if 'initial_total_particles' in locals() and 'initial_histogram_sum' in locals():
# Only log the particle counts and histogram total to the log file
if enable_logging:
logger.info(f"Initial Data: Original Size = {particles_original_count}, Particles = {initial_total_particles}, Histogram Total = {initial_histogram_sum:.0f}")
else:
# Find initial file if we don't have statistics
initial_file = next((f for f in output_files if "initial" in f), None)
if initial_file:
print_status(f"Initial Data: Generated histogram saved to {initial_file}")
# Display final histogram statistics
if 'final_max' in locals() and 'final_nonzero' in locals() and 'final_mean' in locals():
print_status(f"Final Data: Max. Value = {final_max}, Non-Zero Bins = {final_nonzero}, Mean Value (Non-Zero) = {final_mean:.2f}")
if 'final_total_particles' in locals() and 'final_histogram_sum' in locals():
# Only log the particle counts and histogram total to the log file
if enable_logging:
logger.info(f"Final Data: Original Size = {particles_final_original_count}, Particles = {final_total_particles}, Histogram Total = {final_histogram_sum:.0f}")
else:
# Find final file if we don't have statistics
final_file = next((f for f in output_files if "final" in f), None)
if final_file:
print_status(f"Final Data: Generated histogram saved to {final_file}")
# Add separator line after statistics
sys.stdout.write(get_separator_line(char='-') + "\n")
# Once data loading is complete, start plot saving progress if needed
if output_files:
# Start a new progress tracker for saving plots
start_combined_progress("particles_histograms_plots", len(output_files))
# Log generated plots with progress indication
if progress_callback:
# Use progress callback for unified progress tracking
for output_file in output_files:
progress_callback(output_file)
else:
# Fallback to standard progress tracking
for output_file in output_files:
update_combined_progress("particles_histograms_plots", output_file)
def plot_nsphere_histograms(suffix, progress_callback=None):
"""
Plots 2D histograms from the nsphere.c-produced binary histogram files.
Expects each file to contain 3 columns and 40000 records (reshaped to 200 × 200).
• data/2d_hist_initial{suffix}.dat for the initial histogram.
• data/2d_hist_final{suffix}.dat for the final histogram.
The plots are created using plt.pcolormesh and saved into the "results" folder.
"""
ncols_hist = 3
hist_dtype = [np.float32, np.float32, np.int32]
# Track plots we'll generate
output_files = []
# Start progress tracking for data loading (4 steps: check initial, load initial, check final, load final)
start_combined_progress("nsphere_histograms_data_loading", 4)
# ---------------------------------------
# Plot nsphere initial histogram
hist_initial_file = f"data/2d_hist_initial{suffix}.dat"
# Update progress - checking initial file
update_combined_progress("nsphere_histograms_data_loading", "checking_initial_file")
if os.path.exists(hist_initial_file):
logger.info(f"Loading initial histogram data from {hist_initial_file}")
# Update progress - loading initial data
update_combined_progress("nsphere_histograms_data_loading", "loading_initial_histogram")
data_init = safe_load_and_filter_bin(
hist_initial_file, ncols_hist, hist_dtype)
if data_init is None:
logger.warning(f"No valid data in {hist_initial_file}")
print_status(f"No valid data in {hist_initial_file}")
else:
if data_init.shape[0] != 40000:
logger.warning(
f"Expected 40000 entries in {hist_initial_file}, got {data_init.shape[0]}")
print_status(
f"Warning: Expected 40000 entries in {hist_initial_file}, got {data_init.shape[0]}")
else:
try:
logger.info(f"Processing initial histogram data, shape: {data_init.shape}")
X_init = data_init[:, 0].reshape((200, 200))
Y_init = data_init[:, 1].reshape((200, 200))
C_init = data_init[:, 2].reshape((200, 200))
# Calculate statistics for histogram
hist_max = np.max(C_init)
hist_nonzero = np.count_nonzero(C_init)
hist_mean = np.sum(C_init)/hist_nonzero if hist_nonzero > 0 else 0
# Get total particles (sum of all counts)
hist_total_particles = np.sum(C_init)
# Record original data size for debugging
hist_original_particles = data_init.shape[0]
logger.info(f"NSphere initial histogram - Original data size: {hist_original_particles}, Histogram sum: {hist_total_particles:.0f}")
logger.info(f"Initial histogram statistics: Min={np.min(C_init)}, Max={hist_max}, Mean={hist_mean:.2f}, Non-Zero Bins={hist_nonzero}")
plt.figure(figsize=(8, 6))
plt.pcolormesh(X_init, Y_init, C_init,
shading='auto', cmap='viridis')
plt.colorbar(label='Counts')
plt.xlabel(r'$r$ (kpc)', fontsize=12)
plt.ylabel(r'$v$ (km/s)', fontsize=12)
plt.title(r'Initial Phase Space Distribution (nsphere.c)', fontsize=14)
output_file = f"results/2d_hist_nsphere_initial{suffix}.png"
plt.savefig(output_file, dpi=150)
plt.close()
output_files.append(output_file)
except Exception as e:
logger.error(f"Error reshaping data from {hist_initial_file}: {e}")
print_status(f"Error reshaping data from {hist_initial_file}: {e}")
else:
logger.warning(f"File {hist_initial_file} not found.")
print_status(f"File {hist_initial_file} not found.")
# ---------------------------------------
# Plot nsphere final histogram
hist_final_file = f"data/2d_hist_final{suffix}.dat"
# Update progress - checking final file
update_combined_progress("nsphere_histograms_data_loading", "checking_final_file")
if os.path.exists(hist_final_file):
logger.info(f"Loading final histogram data from {hist_final_file}")
# Update progress - loading final data
update_combined_progress("nsphere_histograms_data_loading", "loading_final_histogram")
data_final = safe_load_and_filter_bin(
hist_final_file, ncols_hist, hist_dtype)
if data_final is None:
logger.warning(f"No valid data in {hist_final_file}")
print_status(f"No valid data in {hist_final_file}")
else:
if data_final.shape[0] != 40000:
logger.warning(
f"Expected 40000 entries in {hist_final_file}, got {data_final.shape[0]}")
print_status(
f"Warning: Expected 40000 entries in {hist_final_file}, got {data_final.shape[0]}")
else:
try:
logger.info(f"Processing final histogram data, shape: {data_final.shape}")
X_final = data_final[:, 0].reshape((200, 200))
Y_final = data_final[:, 1].reshape((200, 200))
C_final = data_final[:, 2].reshape((200, 200))
# Calculate statistics for histogram
# Use different variable names for final histogram to avoid overwriting initial values
final_max = np.max(C_final)
final_nonzero = np.count_nonzero(C_final)
final_mean = np.sum(C_final)/final_nonzero if final_nonzero > 0 else 0
# Get total particles (sum of all counts)
final_total_particles = np.sum(C_final)
# Record original data size for debugging
final_original_particles = data_final.shape[0]
logger.info(f"NSphere final histogram - Original data size: {final_original_particles}, Histogram sum: {final_total_particles:.0f}")
logger.info(f"Final histogram statistics: Min={np.min(C_final)}, Max={final_max}, Mean={final_mean:.2f}, Non-Zero Bins={final_nonzero}")
plt.figure(figsize=(8, 6))
plt.pcolormesh(X_final, Y_final, C_final,
shading='auto', cmap='viridis')
plt.colorbar(label='Counts')
plt.xlabel(r'$r$ (kpc)', fontsize=12)
plt.ylabel(r'$v$ (km/s)', fontsize=12)
plt.title(r'Final Phase Space Distribution (nsphere.c)', fontsize=14)
output_file = f"results/2d_hist_nsphere_final{suffix}.png"
plt.savefig(output_file, dpi=150)
plt.close()
output_files.append(output_file)
except Exception as e:
logger.error(f"Error reshaping data from {hist_final_file}: {e}")
print_status(f"Error reshaping data from {hist_final_file}: {e}")
else:
logger.warning(f"File {hist_final_file} not found.")
print_status(f"File {hist_final_file} not found.")
# Print histogram statistics to console for valid data
if 'hist_max' in locals() and 'hist_nonzero' in locals() and 'hist_mean' in locals():
print_status("NSphere histogram statistics:")
print_status(f"Initial Data: Max. Value = {hist_max}, Non-Zero Bins = {hist_nonzero}, Mean Value (Non-Zero) = {hist_mean:.2f}")
# Only log the particle counts to the log file
if enable_logging and 'hist_total_particles' in locals() and 'hist_original_particles' in locals():
logger.info(f"NSphere Initial: Original Size = {hist_original_particles}, Total Particles = {hist_total_particles:.0f}")
if 'final_max' in locals() and 'final_nonzero' in locals() and 'final_mean' in locals():
print_status(f"Final Data: Max. Value = {final_max}, Non-Zero Bins = {final_nonzero}, Mean Value (Non-Zero) = {final_mean:.2f}")
# Only log the particle counts to the log file
if enable_logging and 'final_total_particles' in locals() and 'final_original_particles' in locals():
logger.info(f"NSphere Final: Original Size = {final_original_particles}, Total Particles = {final_total_particles:.0f}")
# Add separator line after statistics
sys.stdout.write(get_separator_line(char='-') + "\n")
# The particles histogram statistics are now displayed in the main histogram statistics section
# Once data loading is complete, start plot saving progress if needed
if output_files:
# Start a new progress tracker for saving plots
start_combined_progress("nsphere_histograms_plots", len(output_files))
# Log generated plots with progress indication
if progress_callback:
# Use progress callback for unified progress tracking
for output_file in output_files:
progress_callback(output_file)
else:
# Fallback to standard progress tracking
for output_file in output_files:
update_combined_progress("nsphere_histograms_plots", output_file)
# Storage for tracked particle IDs used in energy analysis
tracked_ids_for_energy = []
# Helper function to process each file with tracked particle IDs for energy-time plot
def process_sorted_energy_file(task_data):
"""
Process an unsorted rank file for energy-time plot, extracting data for specific tracked particle IDs.
Parameters
----------
task_data : tuple
A tuple containing (fname, local_suffix) where:
- fname is the path to the unsorted Rank data file
- local_suffix is the suffix for file identification
Returns
-------
tuple or None
(snapshot_number, energy_data) if successful or None if there was an error.
"""
# Unpack arguments
fname, local_suffix = task_data
global ncol_Rank_Mass_Rad_VRad_unsorted
# Extract the timestep from the filename
m = re.search(r'Rank_Mass_Rad_VRad_unsorted_t(\d+)' + re.escape(local_suffix) + r'\.dat$', fname)
if not m:
logger.warning(f"Regex failed for sorted energy file {fname} with suffix '{local_suffix}'")
return None
snap = int(m.group(1))
# Reload tracked IDs inside the worker
lowest_radius_ids_file = f"data/lowest_radius_ids{local_suffix}.dat"
id_data = safe_load_and_filter_bin(lowest_radius_ids_file, ncols=2, dtype=[np.int32, np.float32])
if id_data is None or len(id_data) < 1:
logger.warning(f"Worker failed to load IDs from {lowest_radius_ids_file}")
return None
max_particles = min(10, id_data.shape[0])
local_tracked_ids = id_data[:max_particles, 0].astype(int)
# Use helper function to load data for specific particle IDs
tracked_energy_data = safe_load_particle_ids_bin(
fname, ncols=ncol_Rank_Mass_Rad_VRad_unsorted, particle_ids=local_tracked_ids, dtype=np.float32
)
if tracked_energy_data is not None and tracked_energy_data.shape[0] > 0:
return (snap, tracked_energy_data)
else:
return None
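# Illustrative sketch (not executed): dispatching the worker above over many snapshot
# files with multiprocessing, assuming `energy_files` is a hypothetical list of
# Rank_Mass_Rad_VRad_unsorted_t*.dat paths and `suffix` is the module-level suffix:
#
#     tasks = [(fname, suffix) for fname in energy_files]
#     with mp.Pool() as pool:
#         results = [r for r in pool.imap(process_sorted_energy_file, tasks) if r is not None]
#     results.sort(key=lambda item: item[0])   # order by snapshot number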
# Helper function for phase space plotting
def phase_space_process_rank_file(fname):
"""
Process a single sorted Rank file for phase space plotting.
Extracts snapshot number and reads data using the appropriate
data types and safe loader.
Parameters
----------
fname : str
The filename of the sorted Rank file to process.
Returns
-------
tuple or None
(snapshot_number, data) if successful, otherwise None.
'data' is a structured numpy array containing the file contents.
"""
# Extract the timestep from the filename.
mo = re.search(r'Rank_Mass_Rad_VRad_sorted_t(\d+)' + re.escape(suffix) + r'\.dat$', fname)
if not mo:
return None
snap = int(mo.group(1))
# Constants for the expected number of columns
ncol_Rank_Mass_Rad_VRad_sorted = 8
# Read the sorted file using safe loader
data = safe_load_and_filter_bin(fname, ncol_Rank_Mass_Rad_VRad_sorted, dtype=[
np.int32, np.float32, np.float32, np.float32, np.float32, np.float32, np.float32, np.float32])
if data is None:
return None
return (snap, data)
def process_unsorted_rank_file(task_data_with_suffix):
"""
Process an unsorted rank file and return the snapshot number and energy values.
Parameters
----------
task_data_with_suffix : tuple
A tuple containing (fname, project_root_path, local_suffix) where:
- fname is the path to the unsorted Rank data file
- project_root_path is the explicit path to the project root directory
- local_suffix is the suffix for file identification
Returns
-------
tuple or None
(snapshot_number, energy_values) if successful or None if there was an error.
"""
# Unpack arguments
fname, project_root_path, local_suffix = task_data_with_suffix
# Extract the timestep from the filename
mo = re.search(r'Rank_Mass_Rad_VRad_unsorted_t(\d+)' + re.escape(local_suffix) + r'\.dat$', fname)
if not mo:
logger.warning(f"Regex failed for unsorted file {fname} with suffix '{local_suffix}'")
return None
snap = int(mo.group(1))
# Constants for the expected number of columns
ncol = ncol_Rank_Mass_Rad_VRad_unsorted # Should be 7
# --- Start: Optimized Seek-Based Read Logic ---
# Get IDs for the first 10 particles
num_particles_to_get = 10
particle_ids_to_get = list(range(num_particles_to_get))
# Use safe_load_particle_ids_bin which uses seek for basic dtypes
# Pass np.float32 to ensure it takes the optimized path
data_subset = safe_load_particle_ids_bin(
fname,
ncols=ncol,
particle_ids=particle_ids_to_get,
dtype=np.float32 # IMPORTANT: Use basic dtype for seek path
)
if data_subset is None or data_subset.shape[0] == 0:
return None
# Extract Energy (column 5) from the loaded subset
# Handle cases where fewer than 10 particles were actually found (rows might contain NaN)
# We only want valid energy values
valid_energy_mask = ~np.isnan(data_subset[:, 5])
E_values = data_subset[valid_energy_mask, 5]
if E_values.size == 0:
return None
# --- End: Optimized Seek-Based Read Logic ---
return (snap, E_values)
def process_sorted_energy(fname):
"""
Process a rank file for energy plots. This function extracts energy values for
specific particles based on their IDs.
Parameters
----------
fname : str
The filename to process.
Returns
-------
tuple or None
(snapshot_number, energy_values) if successful or None if there was an error.
"""
# Extract the timestep from the filename
mo = re.search(r'Rank_Mass_Rad_VRad_sorted_t(\d+)' + re.escape(suffix) + r'\.dat$', fname)
if not mo:
return None
snap = int(mo.group(1))
# Load the particle IDs to track (particles with lowest initial radius)
lowest_radius_ids_file = f"data/lowest_radius_ids{suffix}.dat"
# Load the lowest_radius_ids file using the binary file handler
id_data = safe_load_and_filter_bin(lowest_radius_ids_file, ncols=2, dtype=[np.int32, np.float32])
if id_data is None or len(id_data) < 1:
print_status(f"Failed to load particle IDs from {lowest_radius_ids_file}")
return None
# Extract the first 10 particles to track (or fewer if less available)
max_particles = min(10, id_data.shape[0])
tracked_ids = id_data[:max_particles, 0].astype(int) # Convert to integers for use as indices
# Now find the corresponding unsorted file for this snapshot
unsorted_fname = f"data/Rank_Mass_Rad_VRad_unsorted_t{snap:05d}{suffix}.dat"
if not os.path.exists(unsorted_fname):
return None
# Use helper function to load data for specific particle IDs
tracked_energy_data = safe_load_particle_ids_bin(
unsorted_fname, ncols=ncol_Rank_Mass_Rad_VRad_unsorted, particle_ids=tracked_ids, dtype=np.float32
)
if tracked_energy_data is None or tracked_energy_data.shape[0] == 0:
return None
# Extract energy values (column 5)
energy_values = tracked_energy_data[:, 5]
return (snap, energy_values)
def process_rank_file(fname):
"""
Process a rank file and return the snapshot number and decimated data.
Parameters
----------
fname : str
The filename to process.
Returns
-------
tuple or None
(snapshot_number, decimated_data) if successful or None if there was an error.
"""
# Extract the timestep from the filename
mo = re.search(r'Rank_Mass_Rad_VRad_sorted_t(\d+)' + re.escape(suffix) + r'\.dat$', fname)
if not mo:
return None
snap = int(mo.group(1))
# Read the file using the safe loader
data = safe_load_and_filter_bin(fname, ncol_Rank_Mass_Rad_VRad_sorted,
dtype=[np.int32, np.float32, np.float32, np.float32,
np.float32, np.float32, np.float32, np.float32])
if data is None or data.shape[0] == 0:
print_status(f"No data to process in {fname}")
return None
# Process the data (no decimation is currently applied; the full-resolution array is returned)
decimated = data
del data # Free full-resolution data
gc.collect()
return (snap, decimated)
def process_rank_file_for_1d_anim(task_data_with_suffix):
"""
Processes a single sorted Rank snapshot file for 1D animations.
Accepts the suffix explicitly rather than relying on the module-level global.
Loads only the required columns (Mass, Radius, Psi, Density), sorts by
radius, filters invalid data, fits linear splines to Mass, Psi and Density,
performs downsampling to a common radius grid, and returns the processed data
ready for animation.
Parameters
----------
task_data_with_suffix : tuple
A tuple containing (fname, project_root_path, local_suffix) where:
- fname is the path to the sorted Rank data file
- project_root_path is the explicit path to the project root directory
- local_suffix is the suffix for file identification
Returns
-------
tuple or None
On success, returns a tuple:
(snap_num, sampled_radii, sampled_mass, sampled_density, sampled_psi)
where arrays have the downsampled length.
Returns None if loading, processing, or spline fitting fails.
"""
# Unpack arguments including the suffix
fname, project_root_path, local_suffix = task_data_with_suffix
# No longer need: global suffix
global ncol_Rank_Mass_Rad_VRad_sorted # Keep this global if needed elsewhere in function
# Define the dtype list for the *entire* row of the sorted file
sorted_rank_dtype_list = [np.int32, np.float32, np.float32, np.float32, np.float32, np.float32, np.float32, np.float32]
# Columns needed: Radius (2, sort key), Mass (1), Psi (4), Density (7)
# Indices passed to loader must be sorted for some selection methods
cols_to_load_indices = sorted([1, 2, 4, 7])
# Map back to original meaning after loading
load_idx_map = {1: 0, 2: 1, 4: 2, 7: 3} # Original Col -> Index in loaded data
radius_load_idx = load_idx_map[2] # Index of Radius within loaded data
# Extract snapshot number from filename
snap_num = _extract_Rank_snapnum(fname, local_suffix)
if snap_num == 999999999:
logger.warning(f"Could not extract valid snapshot number from {fname} using suffix '{local_suffix}' for 1D anim. Skipping.")
return None
try:
# Load only the necessary columns
loaded_cols = load_specific_columns_bin(
fname,
ncols_total=ncol_Rank_Mass_Rad_VRad_sorted, # Total cols in file (8)
cols_to_load=cols_to_load_indices,
dtype_list=sorted_rank_dtype_list
)
if loaded_cols is None or len(loaded_cols) != len(cols_to_load_indices):
logger.warning(f"Failed to load required columns for snap {snap_num} from {fname}")
return None
# Combine into a temporary array for sorting/filtering
# Order: Mass, Radius, Psi, Density (based on sorted cols_to_load_indices)
temp_data = np.column_stack(loaded_cols)
del loaded_cols # Free memory
gc.collect()
# Sort by Radius (which is at index radius_load_idx)
sort_indices = np.argsort(temp_data[:, radius_load_idx])
temp_data_sorted = temp_data[sort_indices, :]
del temp_data, sort_indices # Free memory
gc.collect()
# Separate columns after sorting
radius_sorted = temp_data_sorted[:, radius_load_idx]
mass_sorted = temp_data_sorted[:, load_idx_map[1]]
psi_sorted = temp_data_sorted[:, load_idx_map[4]]
density_sorted = temp_data_sorted[:, load_idx_map[7]]
del temp_data_sorted # Free memory
gc.collect()
# Filter out duplicates in radius (needed for spline) and non-finite values
unique_radii, unique_indices = np.unique(radius_sorted, return_index=True)
# Keep only data corresponding to unique radii
radius_unique = unique_radii
mass_unique = mass_sorted[unique_indices]
psi_unique = psi_sorted[unique_indices]
density_unique = density_sorted[unique_indices]
# Filter for finite values in all arrays
finite_mask = ( np.isfinite(radius_unique) &
np.isfinite(mass_unique) &
np.isfinite(psi_unique) &
np.isfinite(density_unique) )
if not np.any(finite_mask):
logger.warning(f"No finite data after filtering for snap {snap_num}")
return None
radius_final = radius_unique[finite_mask]
mass_final = mass_unique[finite_mask]
psi_final = psi_unique[finite_mask]
density_final = density_unique[finite_mask]
if len(radius_final) < 10: # Need enough points for spline
logger.warning(f"Too few points ({len(radius_final)}) for spline fitting snap {snap_num}")
return None
# --- Spline Fitting ---
try:
# Use k=1 for linear interpolation, s=0 to force interpolation
spl_mass = UnivariateSpline(radius_final, mass_final, k=1, s=0, ext='raise')
spl_psi = UnivariateSpline(radius_final, psi_final, k=1, s=0, ext='raise')
spl_density = UnivariateSpline(radius_final, density_final, k=1, s=0, ext='raise')
except Exception as spline_e:
logger.error(f"Spline fitting failed for snap {snap_num}: {spline_e}")
return None
# --- Downsampling ---
r_min, r_max = radius_final.min(), radius_final.max()
# Ensure min/max reasonable
if r_min <= 0 or r_max <= r_min:
logger.warning(f"Invalid radius range [{r_min}, {r_max}] for snap {snap_num}")
return None
n_particles_approx = len(radius_final) # Use length after filtering
n_samples = max(int(n_particles_approx / 100), 10000) # Downsample factor 100, min 10k points
sampled_radii = np.linspace(r_min, r_max, n_samples)
# --- Evaluate Splines ---
sampled_mass = spl_mass(sampled_radii)
sampled_psi = spl_psi(sampled_radii)
sampled_density = spl_density(sampled_radii)
# Return the downsampled data
return (snap_num, sampled_radii, sampled_mass, sampled_density, sampled_psi)
except Exception as e:
logger.error(f"Error processing {fname} for 1D anim: {e}")
logger.error(traceback.format_exc())
return None
finally:
gc.collect() # Cleanup worker memory
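# Illustrative sketch (not executed): the resampling step used above in isolation.
# A linear UnivariateSpline (k=1, s=0) interpolates the sorted profile, which is then
# evaluated on a uniform radius grid (hypothetical values):
#
#     r = np.array([0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0, 200.0])
#     m = np.linspace(0.01, 1.0, r.size)        # hypothetical cumulative mass profile
#     spl = UnivariateSpline(r, m, k=1, s=0)
#     r_grid = np.linspace(r.min(), r.max(), 10000)
#     m_grid = spl(r_grid)                      # interpolated profile on the uniform grid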
def preprocess_phase_space_file(rank_file_data_with_suffix):
"""
Preprocess a single Rank file for phase space animation.
Parameters
----------
rank_file_data_with_suffix : tuple
Tuple containing (rank_file, placeholder, ncol, max_r_all, max_v_all, nbins,
kmsec_to_kpcmyr, project_root_path, local_suffix).
Returns
-------
tuple or None
(snap_num, H, frame_vmax) if successful or None if processing failed.
"""
# Unpack arguments including the suffix
rank_file, _, ncol, max_r_all, max_v_all, nbins, kmsec_to_kpcmyr, project_root_path, local_suffix = rank_file_data_with_suffix
# Extract snapshot number for labeling
snap_num = _extract_Rank_snapnum(rank_file, local_suffix)
if snap_num == 999999999:
logger.warning(f"Could not extract valid snapshot number from {rank_file} using suffix '{local_suffix}' for phase space. Skipping.")
return None
# Load the snapshot data
data = safe_load_and_filter_bin(
rank_file,
ncol,
dtype=[np.int32, np.float32, np.float32, np.float32, np.float32, np.float32, np.float32, np.float32]
)
if data is None or data.shape[0] == 0:
return None
# Extract radius and velocity data from sorted data with columns:
# Rank(0), Mass(1), Radius(2), Vrad(3), Psi(4), Energy(5), L(6), Density(7)
radii = data[:, 2] # Column index 2 holds radius (r) - simulation units
radial_velocity = data[:, 3] # Column index 3 holds radial velocity (v_r) - simulation units
angular_momentum = data[:, 6] # Column index 6 holds angular momentum (L) - simulation units
# Filter out non-positive radii to avoid division by zero
nonzero_mask = radii > 0
radii = radii[nonzero_mask]
radial_velocity = radial_velocity[nonzero_mask]
angular_momentum = angular_momentum[nonzero_mask]
# Skip if not enough data points
if len(radii) < 10:
return None
# Compute the total velocity - ALL IN SIMULATION UNITS
# v_total_sim = sqrt((L/r)^2 + v_r^2)
tangential_velocity_sim = angular_momentum / radii # Tangential velocity in simulation units
total_velocity_sim = np.sqrt(tangential_velocity_sim**2 + radial_velocity**2) # Total velocity in simulation units
# Convert to km/s at the end
total_velocity = total_velocity_sim / kmsec_to_kpcmyr # Convert to km/s
# Filter out extreme values for visualization
valid_mask = (total_velocity > 0) & (total_velocity < 1e8)
radii = radii[valid_mask]
total_velocity = total_velocity[valid_mask]
# Skip if not enough valid data points
if len(radii) < 10:
return None
# Create the 2D histogram
H, xedges, yedges = np.histogram2d(
radii,
total_velocity,
bins=[nbins, nbins],
range=[[0, max_r_all], [0, max_v_all]]
)
# Calculate a reasonable vmax for the frame
frame_vmax = calculate_reasonable_vmax(H)
# Clean up memory before returning
del radii, total_velocity, radial_velocity, angular_momentum, data
gc.collect()
return (snap_num, H, frame_vmax)
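# Illustrative sketch (not executed): the binning performed above, on a hypothetical
# (radii_kpc, v_tot_kms) sample with the same axis limits and bin count:
#
#     H, xedges, yedges = np.histogram2d(radii_kpc, v_tot_kms, bins=[200, 200],
#                                        range=[[0, 250.0], [0, 320.0]])
#     # H[i, j] counts particles with radius in x-bin i and velocity in y-bin j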
def generate_phase_space_animation(suffix, fps=10):
"""
Creates an animation showing the evolution of the phase space distribution over time.
Parameters
----------
suffix : str
Suffix for input/output files.
fps : int, optional
Frames per second for the animation, by default 10.
Returns
-------
bool
True if animation was created successfully, False otherwise.
Notes
-----
Uses parallel processing for preprocessing histogram data and rendering frames.
Saves the animation incrementally using `imageio.get_writer` to reduce
peak memory usage compared to collecting all frames first.
Requires imageio v2 (`pip install imageio==2.*`).
"""
# Find all Rank sorted files with the exact suffix
rank_files = glob.glob(f"data/Rank_Mass_Rad_VRad_sorted_t*{suffix}.dat")
# Filter to keep ONLY files that match the exact pattern:
# data/Rank_Mass_Rad_VRad_sorted_t00001_40000_1001_5.dat
# without any extra characters between "t00001" and the suffix
correct_pattern = re.compile(r'data/Rank_Mass_Rad_VRad_sorted_t\d+' + re.escape(suffix) + r'\.dat$')
rank_files = [f for f in rank_files if correct_pattern.match(f)]
# Debug output
global enable_logging
if enable_logging:
log_message(f"Found {len(rank_files)} rank files with pattern: data/Rank_Mass_Rad_VRad_sorted_t*{suffix}.dat (after filtering)")
if not rank_files:
print_status("No Rank sorted files found. Cannot create phase space animation.")
return False
# Sort files by snapshot number to ensure correct animation sequence
rank_files.sort(key=lambda fname: _extract_Rank_snapnum(fname, suffix))
# Log the exact file count (to log file only)
file_count = len(rank_files)
log_message(f"Found {file_count} rank files for phase space animation.")
# For debugging - list the first few files
if file_count > 0:
log_message(f"Example files (first 3): {', '.join(rank_files[:3])}")
# Constants
ncol_Rank_Mass_Rad_VRad_sorted = 8
kmsec_to_kpcmyr = 1.02271e-3 # Conversion factor
# Set up histogram parameters
max_r_all = 250.0
max_v_all = 320.0
nbins = 200
# Prepare data for parallel preprocessing
# Log to file only, keep console output minimal
log_message(f"Processing {len(rank_files)} files for phase space animation...")
# Prepare data for parallel preprocessing, ensuring suffix is explicitly passed
preprocess_args_with_suffix = [(
rank_file,
None, # Placeholder for original suffix position
ncol_Rank_Mass_Rad_VRad_sorted,
max_r_all,
max_v_all,
nbins,
kmsec_to_kpcmyr,
PROJECT_ROOT,
suffix # Add the suffix at the end of the tuple
) for rank_file in rank_files]
# Process all files in parallel
with mp.Pool() as pool:
# Use a custom tqdm instance with two progress displays
# First, create a standard tqdm for the counter on the first line
# Determine description and format dynamically
try:
term_width = shutil.get_terminal_size().columns
except OSError:
term_width = 100 # Fallback width
selected_desc, selected_bar_format = select_tqdm_format('preproc_phase', term_width)
counter_tqdm = tqdm(
total=len(rank_files),
desc=selected_desc, # Use dynamic description
unit="file",
position=0,
leave=True,
miniters=1,
dynamic_ncols=True,
ncols=None,
bar_format=selected_bar_format
)
bar_tqdm = tqdm(
total=len(rank_files),
position=1,
leave=True,
dynamic_ncols=True,
bar_format="{bar} {percentage:3.1f}%",
ascii=False
)
# Process the files with a custom callback to update both progress bars
results = []
# Use the modified arguments list
for result in pool.imap(preprocess_phase_space_file, preprocess_args_with_suffix):
results.append(result)
counter_tqdm.update(1)
bar_tqdm.update(1)
counter_tqdm.close()
bar_tqdm.close()
# Add delay after progress bars if enabled
if progress_delay > 0 and paced_mode:
show_progress_delay(progress_delay)
# Log to file only
log_message(f"Processed {len([r for r in results if r is not None])} phase space files successfully")
# Filter out None results and calculate global vmax
frame_data_list = [r for r in results if r is not None]
if not frame_data_list:
print_status("No valid frames found. Cannot create phase space animation.")
return False
# Find the global maximum value for consistent color scaling
initial_vmax = max(vmax for _, _, vmax in frame_data_list)
# Extract tfinal_factor from suffix if possible
# Format is typically _[file_tag]_npts_Ntimes_tfinal_factor
tfinal_factor = 5 # Default value
parts = suffix.strip('_').split('_')
if len(parts) >= 3:
try:
tfinal_factor = int(parts[-1])
except (ValueError, IndexError):
# Use default if parsing fails
pass
# Update vmax for all frames to use the consistent value and add tfinal_factor
# Update vmax for all frames and add total_frames to each tuple for worker access
total_frames = len(frame_data_list)
frame_data_list = [(snap_num, H, initial_vmax, tfinal_factor, total_frames) for snap_num, H, _ in frame_data_list]
# Render frames in parallel
# Still set the global for backward compatibility, though workers won't use it
global total_snapshots
total_snapshots = total_frames
# Log detailed info to file only
log_message(f"Generating phase space frames for {total_frames} snapshots...")
# Set up imageio writer
phase_anim_output = f"results/Phase_Space_Animation{suffix}.gif"
# Use seconds per frame for imageio v2 duration
frame_duration_sec_v2 = 1.0 / fps # fps is frames per second, duration is seconds per frame
try:
# Use mode='I' for multiple images, loop=0 for infinite loop
writer = imageio.get_writer(
phase_anim_output,
format='GIF-PIL', # Explicitly use Pillow
mode='I',
# quantizer='nq', # Temporarily removed
palettesize=256, # Ensure full palette
duration=frame_duration_sec_v2,
loop=0
)
except Exception as e:
print_status(f"Error creating GIF writer: {e}")
return False # Cannot proceed without writer
# Use multiple processes to render frames
with mp.Pool() as pool:
# Use a custom tqdm instance with two progress displays
# First, create a standard tqdm for the counter on the first line
# Determine description and format dynamically
try:
term_width = shutil.get_terminal_size().columns
except OSError:
term_width = 100 # Fallback width
selected_desc, selected_bar_format = select_tqdm_format('render_phase', term_width)
counter_tqdm = tqdm(
total=total_frames,
desc=selected_desc, # Use dynamic description
unit="frame",
position=0,
leave=True,
miniters=1,
dynamic_ncols=True,
ncols=None,
bar_format=selected_bar_format
)
bar_tqdm = tqdm(
total=total_frames,
position=1,
leave=True,
dynamic_ncols=True,
bar_format="{bar} {percentage:3.1f}%",
ascii=False
)
# Process the frames with a custom callback to update both progress bars
frame_count = 0
for frame_image in pool.imap(render_phase_frame, frame_data_list):
if frame_image is not None:
try:
writer.append_data(frame_image) # Append to writer
frame_count += 1
except Exception as e:
log_message(f"Error appending frame {frame_count+1}: {e}", level="error")
# Decide whether to break or continue
break
# Update progress bars
counter_tqdm.update(1)
bar_tqdm.update(1)
counter_tqdm.close()
bar_tqdm.close()
# Close the writer
try:
writer.close()
except Exception as e:
log_message(f"Error closing GIF writer: {e}", level="error")
# Add delay after progress bars if enabled
if progress_delay > 0 and paced_mode:
show_progress_delay(progress_delay)
# Log to file only
log_message(f"Generated and saved {frame_count}/{total_frames} phase space frames successfully to {phase_anim_output}")
sys.stdout.write(get_separator_line(char='-') + "\n")
if frame_count == total_frames:
print_status(f"Animation saved: {get_file_prefix(phase_anim_output)}")
# Clean up
del frame_data_list
gc.collect()
return True
else:
print_status(f"Animation partially saved ({frame_count}/{total_frames}): {get_file_prefix(phase_anim_output)}")
# Clean up
del frame_data_list
gc.collect()
return False
def generate_all_1D_animations(suffix, duration):
"""
Generate all three 1D profile animations (mass, density, psi) with optimized parallel processing.
Parameters
----------
suffix : str
Suffix for input/output files.
duration : float
Duration of each frame in milliseconds.
Notes
-----
This function runs each animation in sequence to ensure clean console output,
but each animation internally uses parallel processing for generating and encoding
frames. This preserves the performance benefit of parallelism while keeping the console output readable.
"""
# Generate 1D profile animations sequentially for clean console output
# Each animation creation function internally uses parallel processing for frames
animations = [
("mass", create_mass_animation),
("density", create_density_animation),
("psi", create_psi_animation)
]
# Track animation success
success_count = 0
# Process each animation in sequence
for name, create_func in animations:
try:
log_message(f"Starting {name} animation generation")
create_func(suffix, duration)
success_count += 1
log_message(f"Completed {name} animation successfully")
except Exception as e:
log_message(f"Error generating {name} animation: {str(e)}", level="error")
print_status(f"Error generating {name} animation: {str(e)}")
# Continue with the next animation despite errors
# Return success status
return success_count == len(animations)
def _extract_Rank_snapnum(fname, suffix):
"""
Extract snapshot number from a rank file name.
Parameters
----------
fname : str
The filename.
suffix : str
The suffix to remove from the filename.
Returns
-------
int
The snapshot number.
"""
mo = re.search(r'Rank_Mass_Rad_VRad_sorted_t(\d+)' + re.escape(suffix) + r'\.dat$', fname)
if mo:
return int(mo.group(1))
return 999999999
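# Illustrative sketch (not executed): the filename parsing above, for a hypothetical
# suffix "_40000_1001_5":
#
#     _extract_Rank_snapnum("data/Rank_Mass_Rad_VRad_sorted_t00042_40000_1001_5.dat",
#                           "_40000_1001_5")                      # -> 42
#     _extract_Rank_snapnum("data/unrelated_file.dat", "_40000_1001_5")  # -> 999999999 sentinel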
def calculate_reasonable_vmax(H):
"""
Calculate an appropriate maximum value for colorbar scaling based on histogram data.
Parameters
----------
H : ndarray
The 2D histogram data array.
Returns
-------
float
A reasonable maximum value for the colorbar, based on the 99.5th percentile
of non-zero values, rounded to a visually pleasing number.
Notes
-----
This function helps avoid having too much empty space in the colorbar while still
maintaining a consistent scale across different frames in animations.
"""
# Get non-zero values from the histogram
nonzero_values = H[H > 0]
if len(nonzero_values) == 0:
return 100 # Default if no non-zero values
# Calculate the 99.5th percentile of non-zero values
vmax = np.percentile(nonzero_values, 99.5)
# Round up to a nice number
if vmax < 10:
vmax = 10
elif vmax < 100:
vmax = np.ceil(vmax / 10) * 10
else:
vmax = np.ceil(vmax / 50) * 50
return vmax
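# Illustrative sketch (not executed): the rounding behaviour of the heuristic above,
# for a hypothetical histogram whose non-zero counts are [3, 137, 200]. Their 99.5th
# percentile is ~199.4, which rounds up to the next multiple of 50:
#
#     H = np.array([[0.0, 3.0], [137.0, 200.0]])
#     calculate_reasonable_vmax(H)   # -> 200.0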
def render_phase_frame(frame_data):
"""
Renders a single frame for the phase space animation.
This function needs to be at the module level for multiprocessing to work.
Parameters
----------
frame_data : tuple
A tuple containing (snap_num, H, vmax, tfinal_factor, total_snapshots)
snap_num : int
The snapshot number for labeling.
H : ndarray
The 2D histogram data.
vmax : float
The maximum value for the colorbar (consistent across all frames).
tfinal_factor : int
Factor relating simulation time steps to dynamical times.
total_snapshots : int
Total number of frames/snapshots for calculating normalized time.
Returns
-------
numpy.ndarray
Image data for the rendered frame.
"""
# Unpack the data including total_snapshots
snap_num, H, vmax, tfinal_factor, total_snapshots = frame_data
# Set up histogram parameters
max_r_all = 250.0
max_v_all = 320.0
nbins = H.shape[0] # Assuming square histogram
# Recreate the meshgrid for pcolormesh
xedges = np.linspace(0, max_r_all, nbins + 1)
yedges = np.linspace(0, max_v_all, nbins + 1)
X, Y = np.meshgrid(xedges[:-1], yedges[:-1])
C = H.T # Transpose for correct orientation
# Create a figure for this frame
fig = plt.figure(figsize=(10, 8))
ax = plt.gca()
# Plot the phase space - Get the colormap from this call to use consistently
cmap = plt.cm.viridis # Default colormap
pcm = plt.pcolormesh(X, Y, C, shading='auto', cmap=cmap, vmin=0, vmax=vmax)
cbar = plt.colorbar(pcm)
cbar.set_label('Counts', fontsize=12)
plt.xlabel(r'$r$ (kpc)', fontsize=12)
plt.ylabel(r'$v$ (km/s)', fontsize=12)
plt.title('Phase Space Distribution Evolution', fontsize=14, pad=20)
plt.xlim(0, max_r_all)
plt.ylim(0, max_v_all)
# Calculate time in dynamical times using the tfinal_factor and passed total_snapshots
if total_snapshots and total_snapshots > 1:
normalized_snap = snap_num / (total_snapshots - 1)
else:
# Approximate by assuming snapshots range from 0-100
# (snapshots often use a zero-padded 5-digit format with max around 00100)
normalized_snap = min(1.0, snap_num / 100)
# Scale by tfinal_factor to get time in dynamical units
t_dyn_fraction = normalized_snap * tfinal_factor
# Add text in upper right corner showing time in dynamical times
# Use the same colormap as the plot for consistency
# Get colors directly from the current colormap
# Map min value to background (usually dark blue/purple in viridis)
# Map max value to text (usually yellow in viridis)
bg_color = pcm.cmap(0.0) # Background color = min value in colormap
text_color = pcm.cmap(1.0) # Text color = max value in colormap
ax.text(0.98, 0.95, f"$t = {t_dyn_fraction:.2f}\\,t_{{\\rm dyn}}$",
transform=ax.transAxes, ha='right', va='top', fontsize=12,
color=text_color,
bbox=dict(facecolor=bg_color, alpha=1.0, edgecolor='none', pad=3))
# Convert figure to image in memory
buf = BytesIO()
plt.savefig(buf, format='png', dpi=100)
buf.seek(0)
# Clean up the figure to avoid memory leaks
plt.close()
# Return the image
return imageio.imread(buf)
def generate_initial_phase_histogram(suffix):
"""
Creates a phase space histogram from the initial particles.dat file.
This represents the true initial distribution.
"""
# Constants for expected number of columns
ncol_particles_initial = 4
# Path to the initial particles file
particles_file = f"data/particles{suffix}.dat"
if not os.path.exists(particles_file):
print_status(f"Initial particles file {particles_file} not found")
return False
# Start loading the initial data with timing information
# Log to file only
logger.info(f"Loading initial phase space histogram: {get_file_prefix(particles_file)}")
start_time = time.time()
# Start progress tracking for data loading
start_combined_progress("phase_space_data_loading", 1)
# Update progress - loading initial data (use filename prefix format)
update_combined_progress("phase_space_data_loading", "particles_data")
# Load the data
data = safe_load_and_filter_bin(
particles_file, ncol_particles_initial, np.float32)
elapsed = time.time() - start_time
if data is None or data.shape[0] == 0:
print_status(f"No valid data in {particles_file}")
return False
# Log to file only
logger.info(f"Initial histogram loaded successfully [{elapsed:.2f}s]")
# Extract data columns - using first 3 relevant columns
iradii = data[:, 0]
radialvelocities = data[:, 1]
ell = data[:, 2]
# Avoid division by zero by filtering out non-positive radii
nonzero_mask = iradii > 0
iradii = iradii[nonzero_mask]
radialvelocities = radialvelocities[nonzero_mask]
ell = ell[nonzero_mask]
# Compute the derived (total) velocity - ALL IN SIMULATION UNITS
# v_total_sim = sqrt((L/r)^2 + v_r^2)
tangential_velocity_sim = ell / iradii # Tangential velocity in simulation units
total_velocity_sim = np.sqrt(tangential_velocity_sim**2 + radialvelocities**2) # Total velocity in simulation units
# Convert to km/s ONLY at the end
kmsec_to_kpcmyr = 1.02271e-3 # Conversion factor
ivelocities = total_velocity_sim / kmsec_to_kpcmyr # Convert to km/s
# Filter out invalid values
valid_mask = (iradii >= 0) & (iradii < 1e8) & (ivelocities > 0) & (ivelocities < 1e8)
iradii = iradii[valid_mask]
ivelocities = ivelocities[valid_mask]
if len(iradii) == 0:
print_status("All initial radii/velocities invalid or empty.")
return False
# Log statistics to file only
logger.info(f"Velocity range: {np.min(ivelocities)} to {np.max(ivelocities)} km/s")
logger.info(f"Radius range: {np.min(iradii)} to {np.max(iradii)} kpc")
# Set up output path
output_file = f"results/phase_space_initial{suffix}.png"
# Start progress tracking for plot saving
start_combined_progress("phase_space_plots", 1)
# Set up histogram parameters
max_r_all = 250.0
max_v_all = 320.0
nbins = 200
# Create the 2D histogram
H, xedges, yedges = np.histogram2d(
iradii,
ivelocities,
bins=[nbins, nbins],
range=[[0, max_r_all], [0, max_v_all]]
)
# Create meshgrid for pcolormesh
X, Y = np.meshgrid(xedges[:-1], yedges[:-1])
C = H.T # Transpose for correct orientation
plt.figure(figsize=(8, 6))
# Use pcolormesh with shading='auto' like in nsphere plot
# Calculate a reasonable vmax based on the histogram data
vmax = calculate_reasonable_vmax(H)
pcm = plt.pcolormesh(X, Y, C, shading='auto', cmap='viridis', vmin=0, vmax=vmax)
cbar = plt.colorbar(pcm)
cbar.set_label('Counts', fontsize=12)
plt.xlabel(r'$r$ (kpc)', fontsize=12)
plt.ylabel(r'$v$ (km/s)', fontsize=12)
plt.title(r'Initial Phase Space Distribution', fontsize=14)
# Set consistent limits
plt.xlim(0, max_r_all)
plt.ylim(0, max_v_all)
os.makedirs("results", exist_ok=True)
start_time = time.time()
plt.savefig(output_file, dpi=150)
plt.close()
elapsed = time.time() - start_time
# Update progress for saving plot
update_combined_progress("phase_space_plots", output_file)
# Log completion message to file only, not to console
log_message(f"Phase space histogram saved: {get_file_prefix(output_file)} [{elapsed:.2f}s]", "info")
logger.info(f"Initial phase space histogram saved to {output_file}")
return True
def generate_comparison_plot(suffix):
"""
Creates a side-by-side comparison of the initial and last available snapshot phase space histograms.
"""
# Constants for expected number of columns
ncol_particles_initial = 4
ncol_Rank_Mass_Rad_VRad_sorted = 8
# Path to the initial particles file
particles_file = f"data/particles{suffix}.dat"
# Find the last available snapshot file
rank_files = glob.glob(f"data/Rank_Mass_Rad_VRad_sorted_t*{suffix}.dat")
if not rank_files:
print_status("No snapshot files found.")
return False
# Sort files by snapshot number and get the last one
rank_files.sort(key=lambda fname: _extract_Rank_snapnum(fname, suffix))
last_snap_file = rank_files[-1]
last_snap_num = _extract_Rank_snapnum(last_snap_file, suffix)
# Log the file search details to the log file, not the console
logger.info(f"Looking for initial file: {particles_file}")
logger.info(f"Looking for last snapshot file: {last_snap_file} (snapshot {last_snap_num})")
if not os.path.exists(particles_file):
print_status(f"Initial particles file {particles_file} not found.")
return False
if not os.path.exists(last_snap_file):
print_status(f"Last snapshot file {last_snap_file} not found.")
return False
# Start progress tracking for data loading (5 steps: 2 for loading, 3 for processing/histograms)
start_combined_progress("phase_space_loading", 5)
# Update progress - both files found, starting to load (use filename prefix)
update_combined_progress("phase_space_loading", "particles_data")
# Load the initial data
initial_data = safe_load_and_filter_bin(particles_file, ncol_particles_initial, np.float32)
if initial_data is None or initial_data.shape[0] == 0:
print_status(f"No valid data in {particles_file}")
return False
# Log the data shape details to the log file
logger.info(f"Initial data loaded, shape: {initial_data.shape}")
# Update progress - loaded initial data (use filename prefix)
update_combined_progress("phase_space_loading", "particles_done")
# Load the last snapshot data
data = safe_load_and_filter_bin(
last_snap_file,
ncol_Rank_Mass_Rad_VRad_sorted,
dtype=[np.int32, np.float32, np.float32, np.float32, np.float32, np.float32, np.float32, np.float32]
)
if data is None or data.shape[0] == 0:
print_status(f"No valid data in {last_snap_file}")
return False
# Log the data shape details to the log file
logger.info(f"Final snapshot data loaded, shape: {data.shape}")
# Update progress - loaded last snapshot data (use filename prefix)
update_combined_progress("phase_space_loading", "snapshot_data")
# Update progress for processing initial data - use file prefix format
update_combined_progress("phase_space_loading", "initial_data")
# Extract data columns - using first 3 relevant columns
iradii = initial_data[:, 0]
radialvelocities = initial_data[:, 1]
ell = initial_data[:, 2]
# Log detailed processing steps to log file
logger.info("Processing initial data - extracting columns and calculating velocities")
# Avoid division by zero by filtering out non-positive radii
nonzero_mask = iradii > 0
iradii = iradii[nonzero_mask]
radialvelocities = radialvelocities[nonzero_mask]
ell = ell[nonzero_mask]
# Compute the derived (total) velocity - ALL IN SIMULATION UNITS
# v_total_sim = sqrt((L/r)^2 + v_r^2)
tangential_velocity_sim = ell / iradii # Tangential velocity in simulation units
total_velocity_sim = np.sqrt(tangential_velocity_sim**2 + radialvelocities**2) # Total velocity in simulation units
# Convert to km/s ONLY at the end
kmsec_to_kpcmyr = 1.02271e-3 # Conversion factor
ivelocities = total_velocity_sim / kmsec_to_kpcmyr # Convert to km/s
# Filter out invalid values
valid_mask = (iradii >= 0) & (iradii < 1e8) & (ivelocities > 0) & (ivelocities < 1e8)
iradii = iradii[valid_mask]
ivelocities = ivelocities[valid_mask]
if len(iradii) == 0:
print_status("All initial radii/velocities invalid or empty.")
return False
# Update progress after processing initial data - use file prefix format
update_combined_progress("phase_space_loading", "final_data")
# Process last snapshot data
# Extract radius and velocity data from sorted data with columns:
# Rank(0), Mass(1), Radius(2), Vrad(3), Psi(4), Energy(5), L(6), Density(7)
radii = data[:, 2] # Column index 2 holds radius (r) - simulation units
radial_velocity = data[:, 3] # Column index 3 holds radial velocity (v_r) - simulation units
angular_momentum = data[:, 6] # Column index 6 holds angular momentum (L) - simulation units
logger.info("Processing final snapshot data - extracting columns and calculating velocities")
# Filter out non-positive radii to avoid division by zero
nonzero_mask = radii > 0
radii = radii[nonzero_mask]
radial_velocity = radial_velocity[nonzero_mask]
angular_momentum = angular_momentum[nonzero_mask]
# Compute the total velocity - ALL IN SIMULATION UNITS
# v_total_sim = sqrt((L/r)^2 + v_r^2)
tangential_velocity_sim = angular_momentum / radii # Tangential velocity in simulation units
total_velocity_sim = np.sqrt(tangential_velocity_sim**2 + radial_velocity**2) # Total velocity in simulation units
# Convert to km/s ONLY at the end
total_velocity = total_velocity_sim / kmsec_to_kpcmyr # Convert to km/s
# Filter out invalid values
valid_mask = (radii >= 0) & (radii < 1e8) & (total_velocity > 0) & (total_velocity < 1e8)
radii = radii[valid_mask]
total_velocity = total_velocity[valid_mask]
if len(radii) == 0:
print_status("All final snapshot radii/velocities invalid or empty.")
return False
# Update progress after processing both datasets - use a data file format for display
histogram_file = f"data/creating_histograms{suffix}.dat"
update_combined_progress("phase_space_loading", histogram_file)
# Set up histogram parameters
max_r_all = 250.0
max_v_all = 320.0
nbins = 200
# Create the 2D histograms
H_initial, xedges_initial, yedges_initial = np.histogram2d(
iradii,
ivelocities,
bins=[nbins, nbins],
range=[[0, max_r_all], [0, max_v_all]]
)
H_last_snap, xedges_last_snap, yedges_last_snap = np.histogram2d(
radii,
total_velocity,
bins=[nbins, nbins],
range=[[0, max_r_all], [0, max_v_all]]
)
# Print histogram statistics for comparison
print_status("Comparison of histogram statistics:")
print_status(f"Initial Data: Max. Value = {np.max(H_initial)}, Non-Zero Bins = {np.count_nonzero(H_initial)}, Mean Value (Non-Zero) = {np.sum(H_initial)/np.count_nonzero(H_initial):.2f}")
print_status(f"Final Data: Max. Value = {np.max(H_last_snap)}, Non-Zero Bins = {np.count_nonzero(H_last_snap)}, Mean Value (Non-Zero) = {np.sum(H_last_snap)/np.count_nonzero(H_last_snap):.2f}")
# Add separator line after statistics
sys.stdout.write(get_separator_line(char='-') + "\n")
# Create meshgrids for pcolormesh
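# Note on orientation: np.histogram2d returns H with shape (nx, ny), where the first
# axis follows its first argument (radius here), while pcolormesh interprets C[row, col]
# with rows running along the y axis; hence the transposes below.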
X_initial, Y_initial = np.meshgrid(xedges_initial[:-1], yedges_initial[:-1])
C_initial = H_initial.T # Transpose for correct orientation
X_last_snap, Y_last_snap = np.meshgrid(xedges_last_snap[:-1], yedges_last_snap[:-1])
C_last_snap = H_last_snap.T # Transpose for correct orientation
# Create the side-by-side plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7), sharex=True, sharey=True)
# Set vmin and vmax to use a consistent color scale across both plots
vmax = calculate_reasonable_vmax(H_initial)
pcm1 = ax1.pcolormesh(X_initial, Y_initial, C_initial, shading='auto', cmap='viridis', vmin=0, vmax=vmax)
ax1.set_xlabel(r'$r$ (kpc)', fontsize=12)
ax1.set_ylabel(r'$v$ (km/s)', fontsize=12)
ax1.set_title(r'Initial Phase Space', fontsize=14)
ax1.set_xlim(0, max_r_all)
ax1.set_ylim(0, max_v_all)
pcm2 = ax2.pcolormesh(X_last_snap, Y_last_snap, C_last_snap, shading='auto', cmap='viridis', vmin=0, vmax=vmax)
ax2.set_xlabel(r'$r$ (kpc)', fontsize=12)
ax2.set_title(fr'Final Phase Space (Snapshot $t={last_snap_num}$)', fontsize=14)
ax2.set_xlim(0, max_r_all)
ax2.set_ylim(0, max_v_all)
cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7]) # [left, bottom, width, height]
cbar = fig.colorbar(pcm2, cax=cbar_ax)
cbar.set_label('Counts', fontsize=12)
fig.suptitle('Comparison of Phase Space Distributions: Initial vs Final', fontsize=16)
# Use fig.subplots_adjust instead of tight_layout for better control
fig.subplots_adjust(left=0.08, right=0.9, top=0.9, bottom=0.1)
# Ensure the results directory exists
os.makedirs("results", exist_ok=True)
output_files = []
output_file = f"results/phase_space_comparison{suffix}.png"
plt.savefig(output_file, dpi=150)
plt.close()
output_files.append(output_file)
# Create a difference plot to highlight changes
fig, ax = plt.subplots(figsize=(10, 8))
# Calculate the difference (last_snap - initial)
diff = C_last_snap - C_initial
# Use a diverging colormap for the difference plot with adjusted scale
# Calculate a reasonable scale based on the data
abs_max = max(abs(np.min(diff)), abs(np.max(diff)))
# Ensure we have a non-zero scale and use a slightly larger value to avoid clipping
vmax_diff = max(abs_max * 1.1, 1.0)
pcm_diff = ax.pcolormesh(X_initial, Y_initial, diff, shading='auto', cmap='RdBu_r',
vmin=-vmax_diff, vmax=vmax_diff) # Dynamic symmetric scale for differences
ax.set_xlabel(r'$r$ (kpc)', fontsize=12)
ax.set_ylabel(r'$v$ (km/s)', fontsize=12)
ax.set_title(r'Phase Space Difference (Final - Initial)', fontsize=14)
ax.set_xlim(0, max_r_all)
ax.set_ylim(0, max_v_all)
cbar = fig.colorbar(pcm_diff)
cbar.set_label('Difference in Counts', fontsize=12)
diff_output_file = f"results/phase_space_difference{suffix}.png"
plt.savefig(diff_output_file, dpi=150)
plt.close()
output_files.append(diff_output_file)
# Start combined progress tracking for the plots
total_plots = len(output_files)
start_combined_progress("phase_space_plots", total_plots)
# Log both plots with progress indication
for output_file in output_files:
update_combined_progress("phase_space_plots", output_file)
return True
[docs]
def plot_convergence_test(Nint_arr, Nspl_arr, basefile, suffix, xlabel, ylabel, title, output_file):
"""
Generate convergence test plots comparing different integration and spline parameters.
Parameters
----------
Nint_arr : list
List of integration parameter values.
Nspl_arr : list
List of spline parameter values.
basefile : str
Base filename to use.
suffix : str
Suffix for input/output files.
xlabel : str
X-axis label (already formatted LaTeX string).
ylabel : str
Y-axis label (already formatted LaTeX string).
title : str
Plot title (already formatted LaTeX string).
output_file : str
Output file path.
"""
plt.figure(figsize=(10, 6))
# Determine number of columns based on basefile
ncols = ncol_convergence # Default for most convergence files
if basefile == "density_profile":
ncols = ncol_density_profile
elif basefile == "massprofile":
ncols = ncol_mass_profile
elif basefile == "Psiprofile":
ncols = ncol_psi_profile
elif basefile == "f_of_E":
ncols = ncol_f_of_E
elif basefile == "integrand":
ncols = ncol_integrand
for Nint in Nint_arr:
for Nspl in Nspl_arr:
filepath = f"data/{basefile}_Ni{Nint}_Ns{Nspl}{suffix}.dat"
if not os.path.exists(filepath):
continue
data = safe_load_and_filter_bin(filepath, ncols, dtype=np.float32)
if data is None:
continue
# Filter out rows with NaN or Inf values
mask = np.all(np.isfinite(data), axis=1)
filtered_data = data[mask]
if len(filtered_data) == 0:
continue
plt.plot(
filtered_data[:, 0],
filtered_data[:, 1],
label=r"$N_{\rm int}=%d$, $N_{\rm spl}=%d$" % (Nint, Nspl)
)
plt.xlabel(xlabel, fontsize=12)
plt.ylabel(ylabel, fontsize=12)
plt.title(title, fontsize=14)
plt.legend(fontsize=10)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig(output_file, dpi=150)
plt.close()
# Log to file only, no console output
logger.info(f"Plot saved: {output_file}")
# Progress tracking is handled by the caller
[docs]
def generate_convergence_test_plots(suffix):
"""
Generate all convergence test plots.
Parameters
----------
suffix : str
Suffix for input/output files.
"""
# Define parameter arrays for testing
Nint_arr = [1000, 10000]
Nspl_arr = [1000, 10000]
# Define plot specifications (basefile, labels, title, output)
plot_specs = [
# Psi profile convergence test
("Psiprofile", r"$r$ (kpc)", r"$\Psi$ (km$^2$/s$^2$)", r"$\Psi(r)$ Profile Convergence Test",
f"results/Psiprofile_convergence_test{suffix}.png"),
# Integrand convergence test - updated labels based on C code analysis
("integrand", r'$\sqrt{\mathcal{E}_{max} - \Psi}$', r'$2 \, d\rho/d\Psi$', r"Integrand Convergence Test",
f"results/integrand_convergence_test{suffix}.png"),
# Distribution function convergence test - corrected units and symbol
("f_of_E", r"$\mathcal{E}$ (km$^2$/s$^2$)", r"$f(\mathcal{E})$ ((km/s)$^{-3}$ kpc$^{-3}$)", r"$f(\mathcal{E})$ Convergence Test",
f"results/f_of_E_convergence_test{suffix}.png"),
# Mass profile convergence test
("massprofile", r"$r$ (kpc)", r"$M(r)$ (M$_\odot$)", r"$M(r)$ Profile Convergence Test",
f"results/massprofile_convergence_test{suffix}.png"),
# Density profile convergence test - uses rho, not 4*pi*r^2*rho
("density_profile", r"$r$ (kpc)", r"$\rho(r)$ (M$_\odot$/kpc$^3$)", r"Density Profile $\rho(r)$ Convergence Test",
f"results/density_profile_convergence_test{suffix}.png")
]
# Start progress tracking for data loading
total_data_files = len(plot_specs)
start_combined_progress("convergence_data_loading", total_data_files)
# First check for data files and log loading progress
for basefile, xlabel, ylabel, title, output_file in plot_specs:
# Check for different Nint/Nspl combinations to see if any files exist
found_files = False
for Nint in Nint_arr:
for Nspl in Nspl_arr:
filepath = f"data/{basefile}_Ni{Nint}_Ns{Nspl}{suffix}.dat"
if os.path.exists(filepath):
found_files = True
logger.info(f"Found convergence test file: {filepath}")
break
if found_files:
break
if found_files:
logger.info(f"Loading data for {basefile} convergence test")
else:
logger.warning(f"No data files found for {basefile} convergence test")
# Update progress with a more filename-like format
if found_files:
# Use the first found file as the reference
for Nint in Nint_arr:
for Nspl in Nspl_arr:
filepath = f"data/{basefile}_Ni{Nint}_Ns{Nspl}{suffix}.dat"
if os.path.exists(filepath):
update_combined_progress("convergence_data_loading", filepath)
break
if os.path.exists(filepath):
break
else:
# If no files were found, still use a filename-like format
dummy_filepath = f"data/{basefile}_convergence{suffix}.dat"
update_combined_progress("convergence_data_loading", dummy_filepath)
# Initialize combined progress tracking for plot generation
total_plots = len(plot_specs)
start_combined_progress("convergence_tests", total_plots)
# Generate all plots with combined progress tracking
for basefile, xlabel, ylabel, title, output_file in plot_specs:
try:
plot_convergence_test(Nint_arr, Nspl_arr, basefile, suffix, xlabel, ylabel, title, output_file)
# Update the combined progress bar
update_combined_progress("convergence_tests", output_file)
except Exception as e:
logger.error(f"Error generating {basefile} convergence test: {str(e)}")
print_status(f"Error generating {basefile} convergence test: {str(e)}")
continue
[docs]
def process_rank_files(suffix, start_snap, end_snap, step_snap):
"""
Processes sorted and unsorted Rank snapshot files using multiprocessing.
- **Sorted Files:** Calls worker `process_rank_file_for_1d_anim` for each
filtered sorted file. This worker loads necessary columns, uses
linear interpolation/downsampling, and returns processed arrays. The
results are used to populate the global `mass_snapshots`,
`density_snapshots`, and `psi_snapshots` lists directly.
- **Unsorted Files:** Calls worker `process_unsorted_rank_file` for each
unsorted file. This worker uses optimized seeking to load energy
data for the first few particles.
- Progress is displayed using tqdm.
Parameters
----------
suffix : str
Suffix for data files.
start_snap : int
First snapshot number to include.
end_snap : int
Last snapshot number to include (0 means use max available).
step_snap : int
Step size between snapshots.
Returns
-------
tuple
(None, unsorted_energy_list): The first element is always None as
sorted data is placed directly into global lists. The second element
is a list of (snap, energy_values) tuples from unsorted files.
"""
global mass_snapshots, density_snapshots, psi_snapshots
# Clear global snapshot data first to ensure we start fresh
mass_snapshots.clear()
density_snapshots.clear()
psi_snapshots.clear()
# Find and process sorted rank files
Rank_pattern = f"data/Rank_Mass_Rad_VRad_sorted_t*{suffix}.dat"
all_rank_files = glob.glob(Rank_pattern)
# Filter to keep ONLY files that match the exact pattern:
# data/Rank_Mass_Rad_VRad_sorted_t00001_40000_1001_5.dat
# without any extra characters between "t00001" and the suffix
correct_pattern = re.compile(r'data/Rank_Mass_Rad_VRad_sorted_t\d+' + re.escape(suffix) + r'\.dat$')
all_rank_files = [f for f in all_rank_files if correct_pattern.match(f)]
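# For example, with suffix "_40000_1001_5" the filter keeps
# "data/Rank_Mass_Rad_VRad_sorted_t00010_40000_1001_5.dat" but rejects
# "data/Rank_Mass_Rad_VRad_sorted_t00010_extra_40000_1001_5.dat".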
# Debug output
global enable_logging
if enable_logging:
log_message(f"Found {len(all_rank_files)} rank files after filtering.")
# Find and process unsorted rank files
unsorted_pattern = f"data/Rank_Mass_Rad_VRad_unsorted_t*{suffix}.dat"
unsorted_files = glob.glob(unsorted_pattern)
# Filter to keep ONLY files that match the exact pattern, same as with the sorted files
correct_unsorted_pattern = re.compile(r'data/Rank_Mass_Rad_VRad_unsorted_t\d+' + re.escape(suffix) + r'\.dat$')
unsorted_files = [f for f in unsorted_files if correct_unsorted_pattern.match(f)]
logger.info(f"Found {len(all_rank_files)} rank sorted snapshot files and {len(unsorted_files)} unsorted snapshot files.")
# Process sorted files for 1D Animations
if not all_rank_files:
logger.warning("No rank sorted snapshot files found for animation.")
print_status("No rank sorted snapshot files found for animation.")
else:
# Sort files by snapshot number
pattern = re.compile(r'Rank_Mass_Rad_VRad_sorted_t(\d+)')
all_rank_files.sort(key=lambda filename: get_snapshot_number(filename, pattern))
# Filter files based on start/end/step if specified
filtered_rank_files = []
for fname in all_rank_files:
snap = get_snapshot_number(fname, pattern)
# Skip files outside the requested range
if start_snap > 0 and snap < start_snap:
continue
if end_snap > 0 and snap > end_snap:
continue
# Skip files not on the requested step
if step_snap > 1 and (snap - start_snap) % step_snap != 0:
continue
filtered_rank_files.append(fname)
logger.info(f"Processing {len(filtered_rank_files)} sorted files for 1D animations...")
with mp.Pool(mp.cpu_count()) as pool:
# Setup tqdm bars
try:
term_width = shutil.get_terminal_size().columns
except OSError:
term_width = 100 # Fallback width
selected_desc, selected_bar_format = select_tqdm_format('proc_sorted_snaps', term_width)
counter_tqdm = tqdm(
total=len(filtered_rank_files),
desc=selected_desc, # Use dynamic description
unit="file",
position=0,
leave=True,
miniters=1,
dynamic_ncols=True, # Allow adapting to terminal width
ncols=None,
bar_format=selected_bar_format
)
# For alignment with the counter line
bar_tqdm = tqdm(
total=len(filtered_rank_files),
position=1,
leave=True,
dynamic_ncols=True,
bar_format="{bar} {percentage:3.1f}%",
ascii=False # Use Unicode block characters
)
# Process files and handle results as they arrive
processed_count_1d = 0
# Prepare arguments including the suffix for the 1D animation worker
args_for_1d_anim = [(fname, PROJECT_ROOT, suffix) for fname in filtered_rank_files]
# Pass the modified arguments list to imap_unordered
for result in pool.imap_unordered(process_rank_file_for_1d_anim, args_for_1d_anim, chunksize=2):
if result is not None:
snap, radii, mass, density, psi = result
# Append directly to global lists
mass_snapshots.append((snap, radii, mass))
density_snapshots.append((snap, radii, density))
psi_snapshots.append((snap, radii, psi))
processed_count_1d += 1
# Update progress bars
counter_tqdm.update(1)
bar_tqdm.update(1)
if counter_tqdm.n % 50 == 0: gc.collect() # Periodic GC in main
counter_tqdm.close()
bar_tqdm.close()
gc.collect() # Collect after pool
# Add separator line after tqdm progress bar
sys.stdout.write(get_separator_line(char='-') + "\n")
# Sort the snapshot lists by snap number AFTER collecting all results
mass_snapshots.sort(key=lambda x: x[0])
density_snapshots.sort(key=lambda x: x[0])
psi_snapshots.sort(key=lambda x: x[0])
logger.info(f"Finished processing {processed_count_1d} sorted files for 1D animations.")
# Process unsorted files
unsorted_energy_list = []
if unsorted_files:
# Process unsorted files for basic energy plot in parallel
logger.info(f"Processing {len(unsorted_files)} unsorted snapshot files...")
with mp.Pool(mp.cpu_count()) as pool:
# Use a custom tqdm instance with two progress displays
# First, create a standard tqdm for the counter on the first line
# Determine description and format dynamically
try:
term_width = shutil.get_terminal_size().columns
except OSError:
term_width = 100 # Fallback width
selected_desc, selected_bar_format = select_tqdm_format('proc_unsorted_snaps', term_width)
counter_tqdm = tqdm(
total=len(unsorted_files),
desc=selected_desc, # Use dynamic description
unit="file",
position=0,
leave=True,
miniters=1,
dynamic_ncols=True, # Allow adapting to terminal width
ncols=None,
bar_format=selected_bar_format
)
# Second bar kept aligned with the counter line (width adapts to the terminal via dynamic_ncols)
bar_tqdm = tqdm(
total=len(unsorted_files),
position=1,
leave=True,
dynamic_ncols=True,
bar_format="{bar} {percentage:3.1f}%",
ascii=False # Use Unicode block characters
)
# Process the files with a custom callback to update both progress bars
results = []
# Prepare arguments including suffix for the unsorted worker
args_for_unsorted = [(fname, PROJECT_ROOT, suffix) for fname in unsorted_files]
# Pass modified args to imap
for result in pool.imap(process_unsorted_rank_file, args_for_unsorted):
results.append(result)
counter_tqdm.update(1)
bar_tqdm.update(1)
counter_tqdm.close()
bar_tqdm.close()
# Add delay after progress bars if enabled
if progress_delay > 0 and paced_mode:
show_progress_delay(progress_delay)
unsorted_energy_list = results
# Add separator line after tqdm progress bar
sys.stdout.write(get_separator_line(char='-') + "\n")
# Remove any files that failed to process
unsorted_energy_list = [x for x in unsorted_energy_list if x is not None]
if not unsorted_energy_list:
logger.warning("No valid unsorted energy data found.")
print_status("No valid unsorted energy data found.")
else:
logger.warning("No unsorted snapshot files found for energy plots.")
print_status("No unsorted snapshot files found for energy plots.")
# Get the maximum available snapshot number
max_available_snap = 0
if all_rank_files:
match = pattern.search(all_rank_files[-1])
if match:
max_available_snap = int(match.group(1))
# If end_snap is 0, use the maximum available snapshot
local_end_snap = end_snap
if local_end_snap == 0:
local_end_snap = max_available_snap
logger.info(f"Using snapshot range: {start_snap} to {local_end_snap} with step {step_snap}")
# Filter files based on start, end, and step parameters
filtered_rank_files = []
for f in all_rank_files:
match = pattern.search(f)
if match:
snap_num = int(match.group(1))
if start_snap <= snap_num <= local_end_snap and (snap_num - start_snap) % step_snap == 0:
filtered_rank_files.append(f)
logger.info(f"Found {len(filtered_rank_files)} rank sorted snapshot files in the specified range.")
# Sort unsorted energy data by snapshot number
if unsorted_energy_list:
unsorted_energy_list.sort(key=lambda x: x[0])
print_status("Particle snapshot data processing complete.")
# Add blank line after completion message
print()
# Return the processed data
return None, unsorted_energy_list
[docs]
def generate_unsorted_energy_plot(unsorted_energy_list, suffix):
"""
Generate energy plot from unsorted rank files.
Parameters
----------
unsorted_energy_list : list
Pre-processed unsorted energy data.
suffix : str
Suffix for input/output files.
Returns
-------
str or None
Path to the generated plot file, or None if no plot was generated.
"""
if not unsorted_energy_list:
logger.info("No unsorted energy data available for plotting.")
return None
# Get snapshot numbers
unsorted_snaps = [snap for (snap, data) in unsorted_energy_list]
# Extract tfinal_factor from suffix if possible
# Format is typically _[file_tag]_npts_Ntimes_tfinal_factor
tfinal_factor = 5 # Default value
parts = suffix.strip('_').split('_')
if len(parts) >= 3:
try:
tfinal_factor = int(parts[-1])
except (ValueError, IndexError):
# Use default if parsing fails
pass
# Convert snapshot numbers to dynamical times
total_snapshots = len(unsorted_snaps)
if total_snapshots > 1:
dyn_times = [snap / (total_snapshots - 1) * tfinal_factor for snap in unsorted_snaps]
else:
dyn_times = [0] # Default if only one snapshot
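# Worked example (illustrative values): a suffix such as "_40000_1001_5" splits into
# ["40000", "1001", "5"], so tfinal_factor = 5; with 101 snapshots, snapshot 50 then
# maps to 50 / 100 * 5 = 2.5 dynamical times.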
unsorted_E_matrix = np.array([data for (snap, data) in unsorted_energy_list])
if unsorted_E_matrix.ndim >= 2 and unsorted_E_matrix.shape[0] > 0 and unsorted_E_matrix.shape[1] > 0:
plt.figure(figsize=(10, 6))
for row_idx in range(min(10, unsorted_E_matrix.shape[1])):
plt.plot(dyn_times, unsorted_E_matrix[:, row_idx]) # Labels removed for clarity
plt.xlabel(r'$t$ ($t_{\rm dyn}$)', fontsize=12)
plt.ylabel(r'$\mathcal{E}$ (km$^2$/s$^2$)', fontsize=12)
plt.title(r'Energy vs. Time (Random Sample)', fontsize=14)
plt.grid(True, linestyle='--', alpha=0.7)
# Legend removed
output_file = f"results/Energy_vs_timestep_unsorted{suffix}.png"
plt.savefig(output_file, dpi=150)
plt.close()
logger.info(f"Saved unsorted energy plot: {output_file}")
# Return the output file path for progress tracking
return output_file
return None
[docs]
def generate_sorted_energy_plot(suffix):
"""
Generate energy plot from sorted rank files (particles with lowest initial radius).
Uses the efficient `process_sorted_energy_file` worker which employs
seek-based loading to extract energy data only for specific particles.
Processes results iteratively to build the final matrix, avoiding
intermediate data accumulation.
Parameters
----------
suffix : str
Suffix for input/output files.
Returns
-------
str or None
Path to the generated plot file, or None if no plot was generated.
Notes
-----
This function uses memory-efficient file reading by:
1. Identifying specific particle IDs to track
2. Using optimized file reading that seeks to specific rows
3. Processing data incrementally to avoid large intermediate arrays
"""
try:
# Load the particle IDs to track (particles with lowest initial radius)
lowest_radius_ids_file = f"data/lowest_radius_ids{suffix}.dat"
# Load the lowest_radius_ids file using the binary file handler
id_data = safe_load_and_filter_bin(lowest_radius_ids_file, ncols=2, dtype=[np.int32, np.float32])
if id_data is None or len(id_data) < 1:
print_status(f"Failed to load particle IDs from {lowest_radius_ids_file}")
return None
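# Assumed file layout (inferred from the dtype above): one row per particle, with
# column 0 holding the particle ID (int32) and column 1 its initial radius (float32).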
# Extract the first 10 particles to track (or fewer if less available)
max_particles = min(10, id_data.shape[0])
tracked_radii = id_data[:max_particles, 1]
# Update the global tracked_ids_for_energy
global tracked_ids_for_energy
tracked_ids_for_energy = id_data[:max_particles, 0].astype(int)
# Collect unsorted data for these specific particles in parallel
unsorted_pattern = f"data/Rank_Mass_Rad_VRad_unsorted_t*{suffix}.dat"
unsorted_files_for_sort = glob.glob(unsorted_pattern)
# Filter to keep ONLY files that match the exact pattern without any extra characters between "t00001" and the suffix
correct_unsorted_pattern = re.compile(r'data/Rank_Mass_Rad_VRad_unsorted_t\d+' + re.escape(suffix) + r'\.dat$')
unsorted_files_for_sort = [f for f in unsorted_files_for_sort if correct_unsorted_pattern.match(f)]
# Log the number of files found after filtering
logger.info(f"Processing {len(unsorted_files_for_sort)} files for sorted energy plot after filtering.")
if not unsorted_files_for_sort:
print_status("No unsorted files found for sorted energy plot.")
return None
# Process files in parallel using the global function
# Log detailed info to file but show condensed version on console
logger.info(f"Processing {len(unsorted_files_for_sort)} files for energy-time plot...")
with mp.Pool(mp.cpu_count()) as pool:
# Use a custom tqdm instance with two progress displays
# First, create a standard tqdm for the counter on the first line
# Determine description and format dynamically
try:
term_width = shutil.get_terminal_size().columns
except OSError:
term_width = 100 # Fallback width
selected_desc, selected_bar_format = select_tqdm_format('proc_energy_series', term_width)
counter_tqdm = tqdm(
total=len(unsorted_files_for_sort),
desc=selected_desc,
unit="file",
position=0, # Position 0 to be the first bar shown
leave=True, # Keep the progress bar visible after completion
miniters=1,
dynamic_ncols=True, # Allow adapting to terminal width
ncols=None,
bar_format=selected_bar_format
)
bar_tqdm = tqdm(
total=len(unsorted_files_for_sort),
position=1, # Position 1 to appear below the counter
leave=True, # Keep the progress bar visible after completion
dynamic_ncols=True,
bar_format="{bar} {percentage:3.1f}%",
ascii=False # Use Unicode block characters
)
# Process the files with a custom callback to update both progress bars
results = []
# Prepare arguments for the worker, including the suffix
args_for_sorted_energy = [(fname, suffix) for fname in unsorted_files_for_sort]
# Pass the modified arguments
for result in pool.imap(process_sorted_energy_file, args_for_sorted_energy):
results.append(result)
counter_tqdm.update(1)
bar_tqdm.update(1)
counter_tqdm.close()
bar_tqdm.close()
# Add delay after progress bars if enabled
if progress_delay > 0 and paced_mode:
show_progress_delay(progress_delay)
# Add a separator line after processing
print(get_separator_line(char='-'))
sorted_energy_list = results
# Remove any files that failed to process
sorted_energy_list = [x for x in sorted_energy_list if x is not None]
if not sorted_energy_list:
print_status("No valid data found for energy-time sorted plot.")
return None
sorted_energy_list.sort(key=lambda x: x[0])
sorted_snaps = [snap for (snap, data) in sorted_energy_list]
# Preallocate a matrix of NaN values with shape (num_snapshots, num_particles)
num_snapshots = len(sorted_energy_list)
expected_particles = max_particles
# Create a matrix filled with NaN values initially
logger.info(f"Creating energy matrix for {num_snapshots} snapshots and {expected_particles} particles")
sorted_E_matrix = np.full((num_snapshots, expected_particles), np.nan)
# Fill in the actual values from each snapshot
for i, (snap, data) in enumerate(sorted_energy_list):
# Log useful information about the data shape
particle_count = min(data.shape[0], expected_particles)
logger.info(f"Snapshot {snap}: Found {particle_count} of {expected_particles} expected particles")
# Fill in available data (may be fewer than expected_particles)
# range() already caps j below data.shape[0], so no extra bounds check is needed
for j in range(min(data.shape[0], expected_particles)):
sorted_E_matrix[i, j] = data[j, 5] # Column 5 is energy
# Log the final data shape
logger.info(f"Final energy matrix shape: {sorted_E_matrix.shape}")
# Extract tfinal_factor from suffix for time conversion
tfinal_factor = 5 # Default value
parts = suffix.strip('_').split('_')
if len(parts) >= 3:
try:
tfinal_factor = int(parts[-1])
except (ValueError, IndexError):
# Use default if parsing fails
pass
# Convert snapshot numbers to dynamical times once (for all particles)
total_snapshots = len(sorted_snaps)
if total_snapshots > 1:
dyn_times = [snap / (total_snapshots - 1) * tfinal_factor for snap in sorted_snaps]
else:
dyn_times = [0] # Default if only one snapshot
# Extract initial energies from the first snapshot
initial_energies = {}
if sorted_energy_list and len(sorted_energy_list) > 0:
first_snap, first_data = sorted_energy_list[0]
for i, row in enumerate(first_data):
if i < len(tracked_ids_for_energy):
initial_energies[tracked_ids_for_energy[i]] = row[5] # Energy is in column 5
if sorted_E_matrix.ndim >= 2 and sorted_E_matrix.shape[0] > 0 and sorted_E_matrix.shape[1] > 0:
plt.figure(figsize=(10, 6))
for row_idx in range(min(max_particles, sorted_E_matrix.shape[1])):
particle_id = tracked_ids_for_energy[row_idx]
init_e = initial_energies.get(particle_id, "N/A")
# Format initial energy with scientific notation if it's a number
if init_e != "N/A":
init_e_str = f"{init_e:.2e}"
else:
init_e_str = "N/A"
# Get the data for this particle
particle_data = sorted_E_matrix[:, row_idx]
# Check if we have any non-NaN data for this particle
if np.any(~np.isnan(particle_data)):
# Use boolean indexing to drop NaN values
valid_indices = ~np.isnan(particle_data)
if np.any(valid_indices):
valid_times = np.array(dyn_times)[valid_indices]
valid_data = particle_data[valid_indices]
# Log how much data was found for this particle
data_percentage = np.sum(valid_indices) / len(valid_indices) * 100
logger.info(f"Particle {particle_id}: Found {np.sum(valid_indices)}/{len(valid_indices)} data points ({data_percentage:.1f}%)")
# Only plot if at least some valid data points exist
if len(valid_data) > 0:
# Use LaTeX for legend
plt.plot(valid_times, valid_data,
label=fr"ID {particle_id} ($r_0$={tracked_radii[row_idx]:.4f}, $\mathcal{{E}}_0$={init_e_str})")
else:
logger.warning(f"No valid data for particle ID {particle_id}")
else:
logger.warning(f"No data available for particle ID {particle_id}")
plt.xlabel(r'$t$ ($t_{\rm dyn}$)', fontsize=12)
plt.ylabel(r'$\mathcal{E}$ (km$^2$/s$^2$)', fontsize=12)
plt.title(r'Energy vs. Time (Lowest Initial Radius)', fontsize=14)
# Legend removed for clarity
plt.grid(True, linestyle='--', alpha=0.7)
output_file = f"results/Energy_vs_timestep_sorted{suffix}.png"
plt.savefig(output_file, dpi=150)
plt.close()
logger.info(f"Saved sorted energy plot: {output_file}")
# Return the output file path for progress tracking
return output_file
except Exception as e:
logger.error(f"Failed to generate energy-time sorted plot: {e}")
logger.error(traceback.format_exc())
return None
[docs]
def create_mass_animation(suffix, duration):
"""
Create mass profile animation.
Parameters
----------
suffix : str
Suffix for input/output files.
duration : float
Duration of each frame in milliseconds.
Returns
-------
None
Function does not return a value, but saves the animation to a file.
Notes
-----
Uses parallel processing for rendering animation frames.
Saves the animation incrementally using `imageio.get_writer` to reduce
peak memory usage compared to collecting all frames first.
Requires imageio v2 (`pip install imageio==2.*`).
"""
global mass_snapshots, mass_max_value
if not mass_snapshots:
print_status("No mass data available for animation.")
return
# Calculate global maximum values for consistent scaling
calculate_global_max_values()
total_frames = len(mass_snapshots)
# Log detailed info to file only
log_message(f"Generating mass profile frames for {total_frames} snapshots...")
# Extract tfinal_factor from suffix if possible
# Format is typically _[file_tag]_npts_Ntimes_tfinal_factor
tfinal_factor = 5 # Default value
parts = suffix.strip('_').split('_')
if len(parts) >= 3:
try:
tfinal_factor = int(parts[-1])
except (ValueError, IndexError):
# Use default if parsing fails
pass
# Calculate r_max for consistent plotting - do this once before creating the pool
r_max = 1.1 * np.max([np.max(r) for _, r, _ in mass_snapshots if len(r) > 0])
r_max = min(200, r_max) # Cap at 200 kpc for reasonable display
# Create frame_data tuples with the actual snapshot data, tfinal_factor, total_frames, r_max and project_root_path
if not mass_snapshots:
print_status("Error: mass_snapshots list is empty before pool creation.")
return # Handle error appropriately
frame_data_list = [(mass_snapshots[i], tfinal_factor, total_frames, r_max, PROJECT_ROOT) for i in range(total_frames)]
# Set up imageio writer
mass_anim_output = f"results/Mass_Profile_Animation{suffix}.gif"
# Use seconds per frame for imageio v2 duration
frame_duration_sec_v2 = duration / 1000.0 # Convert ms to seconds
try:
# Use mode='I' for multiple images, loop=0 for infinite loop
writer = imageio.get_writer(
mass_anim_output,
format='GIF-PIL', # Explicitly use Pillow
mode='I',
# quantizer='nq', # Temporarily removed
palettesize=256, # Ensure full palette
duration=frame_duration_sec_v2,
loop=0
)
except Exception as e:
print_status(f"Error creating GIF writer: {e}")
return # Cannot proceed without writer
with mp.Pool(mp.cpu_count()) as pool:
# Use a custom tqdm instance with two progress displays
# First, create a standard tqdm for the counter on the first line
# Determine description and format dynamically
try:
term_width = shutil.get_terminal_size().columns
except OSError:
term_width = 100 # Fallback width
selected_desc, selected_bar_format = select_tqdm_format('gen_mass_frames', term_width)
counter_tqdm = tqdm(
total=total_frames,
desc=selected_desc, # Use dynamic description
unit="frame",
position=0,
leave=True,
miniters=1,
dynamic_ncols=True,
ncols=None,
bar_format=selected_bar_format
)
bar_tqdm = tqdm(
total=total_frames,
position=1,
leave=True,
dynamic_ncols=True,
ncols=None,
bar_format="{bar} {percentage:3.1f}%",
ascii=False
)
# Process the frames and append directly to the writer
frame_count = 0
for frame_image in pool.imap(render_mass_frame, frame_data_list):
if frame_image is not None:
try:
writer.append_data(frame_image) # Append directly to writer
frame_count += 1
except Exception as e:
log_message(f"Error appending frame {frame_count+1}: {e}", level="error")
break
# Update progress bars
counter_tqdm.update(1)
bar_tqdm.update(1)
counter_tqdm.close()
bar_tqdm.close()
# Close the writer
try:
writer.close()
except Exception as e:
log_message(f"Error closing GIF writer: {e}", level="error")
# No delay between frame generation and animation saving to improve fluidity
# Log to file
log_message(f"Generated and saved {frame_count}/{total_frames} mass profile frames successfully to {mass_anim_output}")
if frame_count > 0:
print_status(f"Animation saved: {get_file_prefix(mass_anim_output)}")
# Add separator line after animation completion
sys.stdout.write(get_separator_line(char='-') + "\n")
else:
print_status("Failed to generate any mass profile frames.")
# Encourage garbage collection
gc.collect()
[docs]
def create_density_animation(suffix, duration):
"""
Create density profile animation.
Parameters
----------
suffix : str
Suffix for input/output files.
duration : float
Duration of each frame in milliseconds.
Returns
-------
None
Function does not return a value, but saves the animation to a file.
Notes
-----
Uses parallel processing for rendering animation frames.
Saves the animation incrementally using `imageio.get_writer` to reduce
peak memory usage compared to collecting all frames first.
Requires imageio v2 (`pip install imageio==2.*`).
"""
global density_snapshots, density_max_value
if not density_snapshots:
print_status("No density data available for animation.")
return
# Calculate global maximum values for consistent scaling
# This only needs to be called once, but we call it for each animation
# to ensure it's always calculated even if only one animation is requested
calculate_global_max_values()
total_frames = len(density_snapshots)
# Log detailed info to file only
log_message(f"Generating density profile frames for {total_frames} snapshots...")
# Extract tfinal_factor from suffix if possible
# Format is typically _[file_tag]_npts_Ntimes_tfinal_factor
tfinal_factor = 5 # Default value
parts = suffix.strip('_').split('_')
if len(parts) >= 3:
try:
tfinal_factor = int(parts[-1])
except (ValueError, IndexError):
# Use default if parsing fails
pass
# Calculate r_max for consistent plotting - do this once before creating the pool
r_max = 1.1 * np.max([np.max(r) for _, r, _ in density_snapshots if len(r) > 0])
r_max = min(300, r_max) # Cap at 300 kpc for reasonable display
# Create frame_data tuples with the actual snapshot data, tfinal_factor, total_frames, r_max and project_root_path
if not density_snapshots:
print_status("Error: density_snapshots list is empty before pool creation.")
return # Handle error appropriately
frame_data_list = [(density_snapshots[i], tfinal_factor, total_frames, r_max, PROJECT_ROOT) for i in range(total_frames)]
# Set up imageio writer
density_anim_output = f"results/Density_Profile_Animation{suffix}.gif"
# Use seconds per frame for imageio v2 duration
frame_duration_sec_v2 = duration / 1000.0 # Convert ms to seconds
try:
# Use mode='I' for multiple images, loop=0 for infinite loop
writer = imageio.get_writer(
density_anim_output,
format='GIF-PIL', # Explicitly use Pillow
mode='I',
# quantizer='nq', # Temporarily removed
palettesize=256, # Ensure full palette
duration=frame_duration_sec_v2,
loop=0
)
except Exception as e:
print_status(f"Error creating GIF writer: {e}")
return # Cannot proceed without writer
with mp.Pool(mp.cpu_count()) as pool:
# Use a custom tqdm instance with two progress displays
# First, create a standard tqdm for the counter on the first line
# Determine description and format dynamically
try:
term_width = shutil.get_terminal_size().columns
except OSError:
term_width = 100 # Fallback width
selected_desc, selected_bar_format = select_tqdm_format('gen_dens_frames', term_width)
counter_tqdm = tqdm(
total=total_frames,
desc=selected_desc, # Use dynamic description
unit="frame",
position=0,
leave=True,
miniters=1,
dynamic_ncols=True,
ncols=None,
bar_format=selected_bar_format
)
bar_tqdm = tqdm(
total=total_frames,
position=1,
leave=True,
dynamic_ncols=True,
ncols=None,
bar_format="{bar} {percentage:3.1f}%",
ascii=False
)
# Process the frames and append directly to the writer
frame_count = 0
for frame_image in pool.imap(render_density_frame, frame_data_list):
if frame_image is not None:
try:
writer.append_data(frame_image) # Append directly to writer
frame_count += 1
except Exception as e:
log_message(f"Error appending frame {frame_count+1}: {e}", level="error")
break
# Update progress bars
counter_tqdm.update(1)
bar_tqdm.update(1)
counter_tqdm.close()
bar_tqdm.close()
# Close the writer
try:
writer.close()
except Exception as e:
log_message(f"Error closing GIF writer: {e}", level="error")
# No delay between frame generation and animation saving to improve fluidity
# Log to file
log_message(f"Generated and saved {frame_count}/{total_frames} density profile frames successfully to {density_anim_output}")
if frame_count > 0:
print_status(f"Animation saved: {get_file_prefix(density_anim_output)}")
# Add separator line after animation completion
sys.stdout.write(get_separator_line(char='-') + "\n")
else:
print_status("Failed to generate any density profile frames.")
# Encourage garbage collection
gc.collect()
[docs]
def create_psi_animation(suffix, duration):
"""
Create psi profile animation.
Parameters
----------
suffix : str
Suffix for input/output files.
duration : float
Duration of each frame in milliseconds.
Returns
-------
None
Function does not return a value, but saves the animation to a file.
Notes
-----
Uses parallel processing for rendering animation frames.
Saves the animation incrementally using `imageio.get_writer` to reduce
peak memory usage compared to collecting all frames first.
Requires imageio v2 (`pip install imageio==2.*`).
"""
global psi_snapshots, psi_max_value
if not psi_snapshots:
print_status("No psi data available for animation.")
return
# Calculate global maximum values for consistent scaling
# This only needs to be called once, but we call it for each animation
# to ensure it's always calculated even if only one animation is requested
calculate_global_max_values()
total_frames = len(psi_snapshots)
# Log detailed info to file only
log_message(f"Generating psi profile frames for {total_frames} snapshots...")
# Extract tfinal_factor from suffix if possible
# Format is typically _[file_tag]_npts_Ntimes_tfinal_factor
tfinal_factor = 5 # Default value
parts = suffix.strip('_').split('_')
if len(parts) >= 3:
try:
tfinal_factor = int(parts[-1])
except (ValueError, IndexError):
# Use default if parsing fails
pass
# Calculate r_max for consistent plotting - do this once before creating the pool
r_max = 1.1 * np.max([np.max(r) for _, r, _ in psi_snapshots if len(r) > 0])
r_max = min(250, r_max) # Cap at 250 kpc for reasonable display
# Create frame_data tuples with the actual snapshot data, tfinal_factor, total_frames, r_max and project_root_path
if not psi_snapshots:
print_status("Error: psi_snapshots list is empty before pool creation.")
return # Handle error appropriately
frame_data_list = [(psi_snapshots[i], tfinal_factor, total_frames, r_max, PROJECT_ROOT) for i in range(total_frames)]
# Set up imageio writer
psi_anim_output = f"results/Psi_Profile_Animation{suffix}.gif"
# Use seconds per frame for imageio v2 duration
frame_duration_sec_v2 = duration / 1000.0 # Convert ms to seconds
try:
# Use mode='I' for multiple images, loop=0 for infinite loop
writer = imageio.get_writer(
psi_anim_output,
format='GIF-PIL', # Explicitly use Pillow
mode='I',
# quantizer='nq', # Temporarily removed
palettesize=256, # Ensure full palette
duration=frame_duration_sec_v2,
loop=0
)
except Exception as e:
print_status(f"Error creating GIF writer: {e}")
return # Cannot proceed without writer
with mp.Pool(mp.cpu_count()) as pool:
# Use a custom tqdm instance with two progress displays
# First, create a standard tqdm for the counter on the first line
# Determine description and format dynamically
try:
term_width = shutil.get_terminal_size().columns
except OSError:
term_width = 100 # Fallback width
selected_desc, selected_bar_format = select_tqdm_format('gen_psi_frames', term_width)
counter_tqdm = tqdm(
total=total_frames,
desc=selected_desc, # Use dynamic description
unit="frame",
position=0,
leave=True,
miniters=1,
dynamic_ncols=True,
ncols=None,
bar_format=selected_bar_format
)
bar_tqdm = tqdm(
total=total_frames,
position=1,
leave=True,
dynamic_ncols=True,
ncols=None,
bar_format="{bar} {percentage:3.1f}%",
ascii=False
)
# Process the frames and append directly to the writer
frame_count = 0
for frame_image in pool.imap(render_psi_frame, frame_data_list):
if frame_image is not None:
try:
writer.append_data(frame_image) # Append directly to writer
frame_count += 1
except Exception as e:
log_message(f"Error appending frame {frame_count+1}: {e}", level="error")
break
# Update progress bars
counter_tqdm.update(1)
bar_tqdm.update(1)
counter_tqdm.close()
bar_tqdm.close()
# Close the writer
try:
writer.close()
except Exception as e:
log_message(f"Error closing GIF writer: {e}", level="error")
# No delay between frame generation and animation saving to improve fluidity
# Log to file
log_message(f"Generated and saved {frame_count}/{total_frames} psi profile frames successfully to {psi_anim_output}")
if frame_count > 0:
print_status(f"Animation saved: {get_file_prefix(psi_anim_output)}")
# Add separator line after animation completion
sys.stdout.write(get_separator_line(char='-') + "\n")
else:
print_status("Failed to generate any psi profile frames.")
# Encourage garbage collection
gc.collect()
[docs]
def process_single_rank_file_for_histogram(suffix):
"""
Process a single rank file (snapshot 0) for generating a histogram.
This is more efficient than processing all snapshots when we only need one.
Parameters
----------
suffix : str
Suffix for input/output files.
Returns
-------
numpy.ndarray or None
The processed data from snapshot 0, or None if not available.
"""
# Find the initial snapshot file
initial_file = f"data/Rank_Mass_Rad_VRad_sorted_t00000{suffix}.dat"
# If the initial file doesn't exist, try to find any sorted rank file
if not os.path.exists(initial_file):
rank_files = glob.glob(f"data/Rank_Mass_Rad_VRad_sorted_t*{suffix}.dat")
if rank_files:
# Sort files to ensure we get the earliest snapshot
rank_files.sort()
initial_file = rank_files[0]
else:
print("No rank files found for histogram.")
return None
data = safe_load_and_filter_bin(initial_file, ncol_Rank_Mass_Rad_VRad_sorted,
dtype=[np.int32, np.float32, np.float32, np.float32,
np.float32, np.float32, np.float32, np.float32])
if data is None or data.shape[0] == 0:
print(f"No data to process in {initial_file}")
return None
return data
[docs]
def generate_rank_histogram(rank_decimated_data, suffix, output_file=None):
"""
Generate 2D histogram from rank data.
Parameters
----------
rank_decimated_data : list or numpy.ndarray
Processed rank data, either as a list of (snap, data) tuples
or as a single data array for a specific snapshot.
suffix : str
Suffix for input/output files.
output_file : str, optional
Custom output file path. If None, a default path will be used.
"""
logger.info("Generating histograms from rank data...")
# Check if input is a list of (snap, data) tuples or a single data array
if isinstance(rank_decimated_data, list):
# Processing list of (snap, data) tuples - locate snapshot 0
initial_data = None
for snap, decimated in rank_decimated_data:
if snap == 0:
initial_data = decimated
break
if initial_data is None:
logger.warning("No initial snapshot found in the existing data")
return None
else:
# Processing single data array (from process_single_rank_file_for_histogram)
initial_data = rank_decimated_data
# Extract columns from the decimated data
ranks = initial_data[:, 0]
masses = initial_data[:, 1]
radii = initial_data[:, 2]
vrad = initial_data[:, 3]
psi = initial_data[:, 4]
energy = initial_data[:, 5]
angular_momentum = initial_data[:, 6]
# Calculate total velocity (including angular momentum component)
# Filter out zero radii before division
nonzero_mask = radii > 0
radii_nz = radii[nonzero_mask]
vrad_nz = vrad[nonzero_mask]
angular_momentum_nz = angular_momentum[nonzero_mask]
if len(radii_nz) == 0:
logger.warning("No particles with non-zero radius found in rank data for histogram.")
return None
# Calculate the total velocity in simulation internal units
total_velocity_internal = np.sqrt(vrad_nz**2 + (angular_momentum_nz**2 / (radii_nz**2)))
# Convert to km/s
kmsec_to_kpcmyr = 1.02271e-3 # Conversion factor (1 km/s ≈ 1.02271e-3 kpc/Myr), matching the value used elsewhere in this module
total_velocity = total_velocity_internal * (1.0 / kmsec_to_kpcmyr)
# Set up histogram parameters
max_r_all = 250.0
max_v_all = 320.0
# Create the 2D histogram using non-zero radius data
hist, xedges, yedges = np.histogram2d(
radii_nz,
total_velocity,
bins=[200, 200],
range=[[0, max_r_all], [0, max_v_all]]
)
# Transpose for correct orientation
hist = hist.T
plt.figure(figsize=(8, 6))
plt.pcolormesh(xedges, yedges, hist, cmap='viridis')
plt.colorbar(label='Counts')
plt.xlabel(r'$r$ (kpc)', fontsize=12)
plt.ylabel(r'$v$ (km/s)', fontsize=12)
plt.title(r'Initial Phase Space Distribution', fontsize=14)
plt.xlim(0, max_r_all)
plt.ylim(0, max_v_all)
if output_file is None:
# Define default output filename
output_file = f"results/part_data_histogram_initial{suffix}.png"
plt.savefig(output_file, dpi=150)
plt.close()
# Log to file (progress tracking by the caller)
logger.info(f"Histogram saved: {output_file}")
return output_file
[docs]
def process_variable_histograms(config):
"""
Generates 1D variable distribution histograms comparing initial vs final distributions.
Creates histograms for radius, velocity, radial velocity, and angular momentum.
Parameters
----------
config : Configuration
The configuration object containing parsed arguments and settings.
Returns
-------
bool
True if successful, False otherwise.
"""
print_header("Generating 1D Variable Distribution Histograms")
# Determine file paths directly within this function
suffix = config.suffix
initial_file = f"data/particles{suffix}.dat"
final_unsorted_file = None
unsorted_pattern_glob = f"data/Rank_Mass_Rad_VRad_unsorted_t*{suffix}.dat"
log_message(f"Searching for unsorted files: {unsorted_pattern_glob}", level="debug")
unsorted_files = glob.glob(unsorted_pattern_glob)
log_message(f"Found {len(unsorted_files)} matching files.", level="debug")
if unsorted_files:
unsorted_regex_pattern = re.compile(r'Rank_Mass_Rad_VRad_unsorted_t(\d+)' + re.escape(suffix) + r'\.dat')
try:
unsorted_files.sort(key=lambda f: get_snapshot_number(f, pattern=unsorted_regex_pattern))
last_file_candidate = unsorted_files[-1]
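# Note: 999999999 appears to be the sentinel get_snapshot_number returns when no
# snapshot number can be parsed (an assumption based on how it is checked here).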
if get_snapshot_number(last_file_candidate, pattern=unsorted_regex_pattern) != 999999999:
final_unsorted_file = last_file_candidate
log_message(f"Identified last unsorted snapshot: {os.path.basename(final_unsorted_file)}", level="info")
else:
log_message("Warning: Sorting unsorted files failed to identify latest.", level="warning")
except Exception as e:
log_message(f"Warning: Error sorting unsorted files: {e}. Using basic sort.", level="warning")
unsorted_files.sort()
if unsorted_files:
final_unsorted_file = unsorted_files[-1]
# Prepare files_to_load_info using determined paths
files_to_load_info = []
if os.path.exists(initial_file):
files_to_load_info.append((initial_file, ncol_particles_initial, np.float32))
else:
log_message(f"Initial particles file not found: {initial_file}", level="error")
print_status(f"Error: Initial particles file not found: {initial_file}")
return False # Exit the function if initial file missing
if final_unsorted_file and os.path.exists(final_unsorted_file):
files_to_load_info.append((final_unsorted_file, ncol_Rank_Mass_Rad_VRad_unsorted, np.float32))
else:
err_msg = f"Last unsorted snapshot file not found or determined for suffix {suffix}."
log_message(err_msg, level="error")
print_status(f"Error: {err_msg}")
return False # Exit the function if final file missing
# --- Data Loading Phase ---
log_message("Starting data loading phase.", level="info")
loading_section_key = "diagnostic_data_loading"
start_combined_progress(loading_section_key, len(files_to_load_info))
loaded_data = {}
load_success = True
for f_path, f_ncol, f_dtype in files_to_load_info:
log_message(f"Loading data from: {f_path}", level="info")
data = safe_load_and_filter_bin(f_path, f_ncol, f_dtype)
loaded_data[f_path] = data
if data is None:
log_message(f"Failed to load or filter data from: {f_path}", level="warning")
load_success = False
update_combined_progress(loading_section_key, f_path) # Pass file path for prefix
initial_raw_data = loaded_data.get(initial_file)
final_unsorted_raw_data = loaded_data.get(final_unsorted_file)
if not load_success or initial_raw_data is None or final_unsorted_raw_data is None:
log_message("Failed to load required data files.", level="error")
print_status("Error: Failed to load required data files.")
if loading_section_key in _combined_plot_trackers:
_combined_plot_trackers.pop(loading_section_key, None)
clear_line()
return False
log_message("Data loading complete.", level="info")
log_message("Processing particle data...", level="info")
init_r, init_vr, init_L, init_vtot = process_particle_data(initial_raw_data, ncol_particles_initial)
final_r, final_vr, final_L, final_vtot = process_particle_data(final_unsorted_raw_data, ncol_Rank_Mass_Rad_VRad_unsorted)
processed_data_valid = all(d is not None for d in [init_r, init_vr, init_L, init_vtot,
final_r, final_vr, final_L, final_vtot])
if not processed_data_valid:
log_message("Could not extract valid data after processing.", level="error")
print_status("Error: Could not extract valid data after processing.")
return False
log_message("Particle data processing complete.", level="info")
# --- Define plot specifications with refined labels and titles ---
plot_vars = [
{'data1': init_r, 'data2': final_r,
'xlabel': r'$r$ (kpc)',
'title': r'Comparison of Radius Distribution $N(r)$',
'filename': f'radius_hist_compare{suffix}.png', 'range': [0, 250]},
{'data1': init_vtot, 'data2': final_vtot,
'xlabel': r'$v$ (km/s)',
'title': r'Comparison of Total Velocity Distribution $N(v)$',
'filename': f'total_velocity_histogram_compare{suffix}.png', 'range': [0, 500]},
{'data1': init_vr, 'data2': final_vr,
'xlabel': r'$v_r$ (km/s)',
'title': r'Comparison of Radial Velocity Distribution $N(v_r)$',
'filename': f'radial_velocity_histogram_compare{suffix}.png', 'range': None}, # Auto range
{'data1': init_L, 'data2': final_L,
'xlabel': r'$\ell$ (simulation units)', # Use \ell
'title': r'Comparison of Angular Momentum Distribution $N(\ell)$', # Use \ell
'filename': f'angular_momentum_histogram_compare{suffix}.png', 'range': None} # Auto range
]
num_plots = len(plot_vars)
log_message(f"Generating {num_plots} comparison histograms...", level="info")
plots_generated = 0
# --- Plot Saving Phase ---
saving_section_key = "diagnostic_plot_saving"
start_combined_progress(saving_section_key, num_plots) # Uses default "Save:"
for i, spec in enumerate(plot_vars, 1):
output_path = os.path.join("results", spec['filename'])
log_message(f"Generating plot: {output_path}", level="info")
plt.figure(figsize=(10, 6))
bins = 100
hist_range = spec['range']
if hist_range is None:
combined = np.concatenate((spec['data1'], spec['data2']))
finite = combined[np.isfinite(combined)]
if finite.size > 0:
hist_range = (np.min(finite), np.max(finite))
else:
hist_range = (-1, 1)
try:
plt.hist(spec['data1'], bins=bins, range=hist_range, alpha=0.6, color='blue', label='Initial', density=True)
plt.hist(spec['data2'], bins=bins, range=hist_range, alpha=0.6, color='red', label='Final', density=True)
plt.xlabel(spec['xlabel'], fontsize=12)
plt.ylabel('Normalized Frequency', fontsize=12)
plt.title(spec['title'], fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7) # Match nsphere_plot grid style
plt.tight_layout()
plt.savefig(output_path, dpi=150) # Consistent DPI
plots_generated += 1
update_combined_progress(saving_section_key, output_path)
log_message(f"Plot saved: {output_path}", level="debug")
except Exception as e:
clear_line()
log_message(f"Error generating plot {output_path}: {e}", level="error")
print(f"\nError generating plot for {spec['xlabel']}: {e}")
update_combined_progress(saving_section_key, output_path) # Pass path for prefix
finally:
plt.close()
# Corrected Footer message logic
if plots_generated == num_plots:
print_footer("1D Variable Comparison Histograms generated successfully.")
log_message("1D Variable Comparison Histograms generated successfully.", level="info")
return True
else:
final_message = f"1D Comparison Histograms partially generated ({plots_generated}/{num_plots})."
print_status(final_message + " Check log for details.")
log_message(final_message + " Some plots failed.", level="warning")
return False
[docs]
def generate_all_2d_histograms(config, rank_decimated_data=None):
"""
Generates all 2D histogram plots including particle, nsphere, and rank histograms.
Parameters
----------
config : Configuration
The configuration object.
rank_decimated_data : list or None
Pre-processed rank data that can be used as fallback for histogram generation.
"""
print_header("Generating 2D Histograms")
# Count total histogram plots (4 from particles and nsphere, 1 from rank)
total_histograms = 5
start_combined_progress("histogram_plots", total_histograms)
# Process particles histograms with combined progress
plot_particles_histograms(config.suffix, progress_callback=lambda output_file: update_combined_progress("histogram_plots", output_file))
# Process nsphere histograms with combined progress
plot_nsphere_histograms(config.suffix, progress_callback=lambda output_file: update_combined_progress("histogram_plots", output_file))
# For rank histogram, only a single file is needed
logger.info("Generating histograms from rank data...")
single_rank_data = process_single_rank_file_for_histogram(config.suffix)
# Generate rank histogram if we have data
rank_output_file = f"results/part_data_histogram_initial{config.suffix}.png"
if single_rank_data is not None:
generate_rank_histogram(single_rank_data, config.suffix, rank_output_file)
update_combined_progress("histogram_plots", rank_output_file)
# No fallback available for histogram generation
print_footer("2D histograms generated successfully.")
[docs]
def generate_all_energy_plots(config, file_paths, unsorted_energy_data):
"""
Generates all energy-related plots with status display and progress tracking.
Displays the initial status, generates plots in parallel (unsorted and sorted
energy plots), and updates the console output with completion status.
Parameters
----------
config : Configuration
The configuration object containing suffix and other settings.
file_paths : dict or None
Dictionary of file paths, will be created if None.
unsorted_energy_data : list or None
Pre-processed unsorted energy data from process_rank_files.
Notes
-----
The console output is managed with a series of up/down cursor movements
to create a clean, consistently updated display. Initial status appears
immediately while plots are being generated, with the final update
overwriting the status lines upon completion.
"""
# Access the global paced_mode variable for the timer function
global paced_mode
# Setup file paths if they weren't provided
if file_paths is None:
file_paths = config.setup_file_paths()
print_header("Generating Energy Plots", add_newline=False)
# Prepare static display lines
energy_plot_start_time = time.time()
unsorted_name = f"Energy_vs_timestep_unsorted{config.suffix}"
sorted_name = f"Energy_vs_timestep_sorted{config.suffix}"
# Log start of processing
logger.info("Energy plots starting")
# Set up progress bar appearance
bar_length = 20
half_filled = int(bar_length * 0.5)
half_bar = '█' * half_filled + ' ' * (bar_length - half_filled)
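# Illustration: with bar_length = 20 this yields ten '█' characters followed by ten
# spaces, i.e. a text rendering of a 50% progress bar.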
prefix = get_file_prefix(unsorted_name)
# Timer thread will handle displaying the 50% unsorted status
# print(f"Save: {half_bar} 50.0% | File: {prefix}") # Commented out as timer thread shows this
# print(get_separator_line(char='-'))
# Create timer thread
stop_timer = threading.Event()
timer_thread = threading.Thread(target=update_timer_energy_plots, args=(stop_timer, energy_plot_start_time, paced_mode))
timer_thread.daemon = True
timer_thread.start()
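# The daemon timer thread keeps the elapsed-time/status line refreshing while
# the (potentially slow) plot generation below runs on the main thread; setting
# the stop_timer Event is how the main thread tells the timer loop to exit.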
try:
# Generate the unsorted energy plot
if unsorted_energy_data:
unsorted_output = generate_unsorted_energy_plot(unsorted_energy_data, config.suffix)
else:
unsorted_output = generate_unsorted_energy_plot([], config.suffix)
# Stop the timer after first plot completes
stop_timer.set()
timer_thread.join(0.2)
# --- Write separator line and move down ---
# Timer status is on line X. Cursor is likely start of X+1.
# Write separator on line X+1. Use truncate_and_pad for consistency.
separator = get_separator_line(char='-')
sys.stdout.write(f"{truncate_and_pad_string(separator)}\n") # Write separator + newline
# Cursor is now at start of line X+2. Tqdm will start here.
sys.stdout.flush()
# --- End separator write + move down ---
# Generate the sorted energy plot
sorted_output = generate_sorted_energy_plot(config.suffix)
# Calculate final timing statistics
elapsed = time.time() - energy_plot_start_time
displayed_elapsed = max(elapsed, 0.01)
nominal_elapsed = max(elapsed, 0.02)
remaining = 0
rate = 2.0 / nominal_elapsed
rate = min(rate, 99.9)
time_info = f" [{displayed_elapsed:.2f}<{remaining:.2f}, {rate:.1f}file/s]" # Final stats
except Exception as e:
# Stop timer thread on error
if timer_thread.is_alive():
stop_timer.set()
timer_thread.join(0.2)
# Log the error
logger.error(f"Error during energy plot generation: {e}")
logger.error(traceback.format_exc())
print_status(f"Error during energy plot generation: {e}")
# Mark plots as failed and provide a fallback timing string so the final
# status update below does not reference an unbound time_info
unsorted_output = None
sorted_output = None
time_info = ""
# Show completion status
full_bar = '█' * bar_length # Keep this
final_prefix = get_file_prefix(sorted_name) # Use SORTED prefix
# --- Start Final Update Logic ---
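# Console layout assumed by the cursor arithmetic below (X = the status line
# written by the timer thread):
#   X    : "Save: ..." status line (rewritten to 100% below)
#   X+1  : '-' separator line (rewritten below)
#   X+2  : first tqdm bar
#   X+3  : second tqdm bar
#   X+4  : line left below the tqdm bars
#   X+5  : cursor position before and after this final update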
# Move up 5 lines (1 init status + 1 init separator + 2 tqdm bars + 1 line below tqdm)
move_up_final = "\033[5A"
# Downward movement: we rewrite the status (line X) and the separator (line X+1),
# then move down 4 lines from X+1 to land back at the original cursor position
# (X+5), just below the tqdm area.
move_down_final = "\033[4B" # Relative move
sys.stdout.write(move_up_final) # Move up to initial status line (X)
# Create the final status string
content_string = f"Save: {full_bar} 100.0% | File: {final_prefix}{time_info}"
# Clear line X and write final status
sys.stdout.write(f"\r\033[2K{truncate_and_pad_string(content_string)}")
# Move to next line (X+1), clear and write separator
sys.stdout.write('\n') # Move cursor to start of line X+1
sys.stdout.write(f"\r\033[2K{get_separator_line(char='-')}") # Clear/write separator
# Move cursor down 4 lines relative to the current line (X+1) to land below tqdm area
sys.stdout.write(move_down_final)
sys.stdout.flush()
# --- End Final Update Logic ---
# Generate additional energy plots
process_trajectory_energy_plots(file_paths, include_angular_momentum=False)
# Display completion message
logger.info("Energy data processing complete.")
print_footer("Energy plots generated successfully.")
def display_configuration_summary(config):
"""
Prints the execution banner and summarizes the configuration parameters.
Parameters
----------
config : Configuration
The configuration object containing parsed arguments and settings.
"""
# Display banner and parameter information
print_header("NSphere Plot Generation Tool")
# Display parameter values in a structured format
print_status("\nParameter values requested:\n")
print_status(f" Number of Particles: {config.npts}")
print_status(f" Number of Time Steps: {config.Ntimes}")
print_status(f" Number of Dynamical Times: {config.tfinal_factor}")
print_status(f" FPS: {config.fps} (frame duration: {config.duration:.1f}ms)")
print_status(f" Output Directory: results")
if config.file_tag:
print_status(f" Filename Tag: {config.file_tag}")
else:
print_status(f" Filename Tag: [none]")
if paced_mode:
print_status(f" Paced Mode: Enabled (section delay: 5.0s, progress bar delay: 2.0s)")
else:
print_status(f" Paced Mode: Disabled (fast mode - no delays)")
if enable_logging:
print_status(f" Logging: Full (log/nsphere_plot.log)")
else:
print_status(f" Logging: Errors and warnings only (log/nsphere_plot.log)")
print_status("")
def main():
"""
Main function that orchestrates the visualization process.
"""
global suffix, start_snap, end_snap, step_snap, duration, args
global mass_snapshots, density_snapshots, psi_snapshots
global showing_help
# Create Configuration object (parses arguments and sets up parameters)
config = Configuration()
# If showing help, exit immediately
if showing_help:
return
# Display configuration summary
display_configuration_summary(config)
# Set global variables from configuration
args = config.args
suffix = config.suffix
start_snap = config.start_snap
end_snap = config.end_snap
step_snap = config.step_snap
duration = config.duration
# Create results directory
os.makedirs("results", exist_ok=True)
# Check if any "only" flags are active
only_flags_active = config.only_specific_visualizations()
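# Gating pattern used for each section below: a plot family runs either when its
# specific config.args.<family> flag is set or when no 'only'-style flags are
# active at all (a normal full run), and it is then skipped if the matching
# config.args.no_<family> flag disables it.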
file_paths = None
rank_decimated_data = None
# Only process phase space plots if specified or in normal mode
if config.args.phase_space or not only_flags_active:
# Skip if explicitly told not to generate phase space plots
if not config.args.no_phase_space:
print_header("Generating Phase Space Initial Histogram")
generate_initial_phase_histogram(config.suffix)
print_footer("Initial phase space histogram created successfully.")
# If the phase-space flag was explicitly requested, also generate the animation right away
if config.args.phase_space:
print_header("Generating Phase Space Animation")
generate_phase_space_animation(config.suffix, fps=config.fps)
# Add separator and completion message for phase space animation
sys.stdout.write(get_separator_line(char='-') + "\n")
print_status("Phase space animation generated successfully.")
# Log this message to file only, not to console
log_message("Phase space animation created successfully.", "info")
# Only process phase comparison if specified or in normal mode
if config.args.phase_comparison or not only_flags_active:
# Skip if explicitly told not to generate phase comparison
if not config.args.no_phase_comparison:
print_header("Generating Phase Space Comparison")
generate_comparison_plot(config.suffix)
print_footer("Phase space comparison created successfully.")
# --- Generate 1D Variable Distributions ---
if config.args.distributions or (not only_flags_active and not config.args.no_distributions):
# This function handles its own header, data loading, processing, and plotting.
process_variable_histograms(config)
# Footer is printed by the function or progress bar completion
# Process profile plots if requested or in normal mode
if config.args.profile_plots or not only_flags_active:
if not config.args.no_profile_plots:
# Setup file paths
file_paths = config.setup_file_paths()
print_header("Generating Profile Plots")
process_profile_plots(file_paths)
print_footer("Profile plots generated successfully.")
# Process trajectory plots if requested or in normal mode
if config.args.trajectory_plots or not only_flags_active:
if not config.args.no_trajectory_plots:
# Setup file paths if not already set
if file_paths is None:
file_paths = config.setup_file_paths()
print_header("Generating Trajectory and Diagnostic Plots")
# Count the total plots to be generated from both functions
trajectory_count = 3 # trajectories, single_trajectory, lowest_l_3panel
energy_count = 3 # energy_vs_time, angular_momentum_vs_time, energy_compare
# Initialize combined progress tracking for all trajectory-related plots
total_plots = trajectory_count + energy_count
start_combined_progress("trajectory_plots", total_plots)
# Use a modified process_trajectory_plots function that updates the combined tracker
process_trajectory_plots(file_paths)
print_footer("Trajectory and diagnostic plots generated successfully.")
# Generate 2D histograms if requested or in normal mode
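# getattr() is needed here because '2d_histograms' starts with a digit, so it is
# not a valid Python identifier and cannot be read with normal attribute access
# (config.args.2d_histograms would be a syntax error).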
if getattr(config.args, '2d_histograms', False) or not only_flags_active:
if not config.args.no_histograms:
generate_all_2d_histograms(config, rank_decimated_data)
# Generate convergence test plots if requested or in normal mode
if config.args.convergence_tests or not only_flags_active:
if not config.args.no_convergence_tests:
print_header("Generating Convergence Test Plots")
# The generate_convergence_test_plots function uses combined progress tracking
generate_convergence_test_plots(config.suffix)
print_footer("Convergence test plots generated successfully.")
# Exit early if only specific, already-handled visualizations were requested
# to avoid unnecessary rank file processing
need_to_process_rank_files = config.need_to_process_rank_files()
if only_flags_active and not need_to_process_rank_files:
print_header("Visualization Complete")
print_footer("All requested visualizations completed successfully.")
return
# Process rank files first if they'll be needed for animations or energy plots
rank_decimated_data = None
unsorted_energy_data = None # Initialize to ensure it's defined
if need_to_process_rank_files:
print_header("Processing Particle Snapshot Data")
# Setup file paths if not already set
if file_paths is None:
file_paths = config.setup_file_paths()
# Process rank files and prepare data for animations and energy plots
# The function will print its own completion message
_, unsorted_energy_data = process_rank_files(config.suffix, config.start_snap, config.end_snap, config.step_snap)
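# Only the unsorted energy data is kept for the energy plots below; the first
# return value is discarded, and rank_decimated_data is left as None.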
rank_decimated_data = None
# Process energy plots if requested or in normal mode
if config.args.energy_plots or (not only_flags_active and not config.args.no_energy_plots):
generate_all_energy_plots(config, file_paths, unsorted_energy_data)
# Process animations if requested or in normal mode
if config.args.animations or (not only_flags_active and not config.args.no_animations):
print_header("Generating Animations")
# Use parallel processing internally for animations; manage progress sequentially here
# This provides parallel processing with clean progress output
generate_all_1D_animations(config.suffix, config.duration)
# Add completion message for profile animations (no separator needed here)
print_status("Mass, density, and psi animations generated successfully.")
# Generate phase space animation if not explicitly disabled and not already generated
if not config.args.no_phase_space and not config.args.phase_space: # Skip if already generated in phase-space-only mode
print_header("Generating Phase Space Animation")
generate_phase_space_animation(config.suffix, fps=config.fps)
# Add completion message for phase space animation (separator is already added in the function)
print_status("Phase space animation generated successfully.")
# Log this message to file only, not to console
log_message("Phase space animation created successfully.", "info")
# Log this message to file only, not to console
log_message("All animations generated successfully.", "info")
# If doing only specific visualizations, exit now
if only_flags_active:
print_header("Visualization Complete")
# Log completion message to file only, not to console
log_message("All Visualizations Complete", "info")
return
else:
print_header("All Visualizations Complete")
# Log completion message to file only, not to console
log_message("All visualizations generated successfully. Results saved to 'results' directory.", "info")
def run_main():
"""
Main entry point that runs the application with the cursor hidden.
Ensures cursor is always restored when the program exits.
"""
try:
# Hide cursor at program start
hide_cursor()
main()
except KeyboardInterrupt:
print("\nProcess interrupted by user")
sys.exit(1)
except Exception as e:
# Log the error and echo it to the console, then re-raise with a bare `raise`
# to preserve the original traceback; the finally block below restores the cursor
logger.error(f"Error: {str(e)}")
logger.error(traceback.format_exc()) # Log full traceback
# Print error to console as well for visibility
print_status(f"\nAn error occurred: {str(e)}")
raise
finally:
# Always ensure cursor is visible before exiting
show_cursor()
if __name__ == '__main__':
run_main()