import base64
import io
import json
import logging
import logging.config
import os
import re
from typing import Any, Dict, Iterable, List, Optional
import yaml
from blessed import Terminal
from PIL import Image
# Module-wide logger shared by the helpers defined in this file.
logger = logging.getLogger(name=__name__)
# Single blessed Terminal reused to build color/bold escape sequences.
_term = Terminal()
[docs]
def load_config(config_path: str, config_name: Optional[str] = None) -> dict:
    """Load a YAML configuration file, or one configuration from a folder or file.

    Args:
        config_path (str): Path to a YAML file, or to a folder containing
            ``<name>.yaml`` files.
        config_name (Optional[str]): If ``config_path`` is a folder, the stem of
            the file to load (``<config_name>.yaml``); required in that case.
            If ``config_path`` is a file, the top-level key to extract; when
            omitted, the whole file is returned.

    Returns:
        dict: The parsed configuration. An empty YAML document yields ``{}``.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist, or
            ``<config_name>.yaml`` is missing from the folder.
        yaml.YAMLError: If the YAML file is malformed or cannot be parsed.
        ValueError: If ``config_name`` is missing for a folder, or the key is
            not found in the file (including when the file's top level is not
            a mapping).
    """
    if os.path.isdir(config_path):
        # A folder holds one file per configuration, so a name is mandatory.
        if config_name is None:
            raise ValueError(
                "config_name must be provided when config_path is a folder."
            )
        config_file_path = os.path.join(config_path, f"{config_name}.yaml")
        if not os.path.exists(config_file_path):
            raise FileNotFoundError(
                f"Configuration file '{config_file_path}' not found in folder '{config_path}'."
            )
        with open(config_file_path, "r", encoding="utf-8") as file:
            data = yaml.safe_load(file)
        # safe_load returns None for an empty document; honour the dict contract.
        return {} if data is None else data

    if os.path.isfile(config_path):
        with open(config_path, "r", encoding="utf-8") as file:
            config_data = yaml.safe_load(file)
        if config_data is None:
            config_data = {}
        if config_name is None:
            return config_data
        # A list/scalar top level cannot hold named keys; report it the same
        # way as a missing key instead of failing with a TypeError.
        if not isinstance(config_data, dict) or config_name not in config_data:
            raise ValueError(
                f"Configuration key '{config_name}' not found in file '{config_path}'."
            )
        return config_data[config_name]

    raise FileNotFoundError(f"Path '{config_path}' does not exist.")
[docs]
def load_prompts(prompts_path: str, which_model: str) -> dict:
    """Load the prompt templates for one model.

    This is a thin wrapper around ``load_config``. ``which_model`` is used as the
    top-level key when ``prompts_path`` is a YAML file. It is used as the file
    stem (``<which_model>.yaml``) when ``prompts_path`` is a folder.

    Args:
        prompts_path (str): Path to the YAML prompts file, or a folder of
            per-model YAML files.
        which_model (str): Key (or file stem) selecting the model's prompts.

    Returns:
        dict: Prompt templates for the specified model.

    Raises:
        FileNotFoundError: If ``prompts_path`` (or the model's file inside the
            folder) does not exist.
        ValueError: If ``which_model`` is not found in the prompts file.
        yaml.YAMLError: If the YAML file cannot be parsed.
    """
    return load_config(prompts_path, which_model)
[docs]
def colorstring(
    message: Any,  # Accept any type of input
    color: Optional[str] = "green",
    bold: bool = False,
) -> str:
    """
    Return ``message`` wrapped in blessed terminal color/bold sequences.

    :param message: The value to color. It is converted with ``str()`` first,
        so any type (str, int, float, bool, ...) is accepted.
    :param color: Color name, case-insensitive. Supported: 'black', 'red',
        'green', 'yellow', 'blue', 'magenta', 'cyan', 'white', and their bright
        variants 'bright_black', 'bright_red', 'bright_green', 'bright_yellow',
        'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white'.
        Unrecognized names fall back to white. ``None`` applies no color.
    :param bold: If True, the (colored) text is additionally wrapped in bold.
    :return: The styled string.
    """
    text = str(message)
    if color is not None:
        name = color.lower()
        if name not in {
            "black", "blue", "cyan", "green", "magenta", "red", "white", "yellow",
            "bright_black", "bright_blue", "bright_cyan", "bright_green",
            "bright_magenta", "bright_red", "bright_white", "bright_yellow",
        }:
            name = "white"  # fallback for unknown color names
        # Resolve only the requested formatter instead of all sixteen per call.
        text = getattr(_term, name)(text)
    if bold:
        text = _term.bold(text)
    return text
[docs]
def cprint(
    message: Any,
    color: Optional[str] = "green",
    *,
    bold: bool = False,
    **kwargs,
) -> None:
    """
    Print ``message`` in color using blessed terminal capabilities.

    :param message: The message to be colored. Can be of any type (e.g., str, int, float, bool).
    :param color: Color name passed to ``colorstring``: 'black', 'red', 'green', 'yellow',
        'blue', 'magenta', 'cyan', 'white', or a 'bright_*' variant (e.g. 'bright_red').
        Unknown names fall back to white.
    :param bold: Keyword-only. If True, the text is also rendered bold.
    :param kwargs: Additional keyword arguments to pass to the `print` function
        (e.g., `end`, `sep`, `file`, `flush`).
    """
    print(colorstring(message, color, bold=bold), **kwargs)
[docs]
def get_image_paths(path: str | List[str]) -> List[str]:
"""
Get image file paths from a specified folder, a single image file, or a list of image paths.
Parameters:
path (Union[str, List[str]]): The path to the folder, the single image file, or a list of image paths.
Returns:
List[str]: A list of image file paths.
Raises:
ValueError: If any path does not exist or is not a valid image file or folder of images.
"""
valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}
def _get_single_image_path(p: str) -> List[str]:
if not os.path.exists(p):
raise ValueError(f"The path '{p}' does not exist.")
if os.path.isfile(p):
if os.path.splitext(p)[1].lower() in valid_extensions:
return [p]
raise ValueError(f"The file '{p}' is not a recognized image file.")
if os.path.isdir(p):
return [
os.path.join(p, f)
for f in os.listdir(p)
if os.path.splitext(f)[1].lower() in valid_extensions
]
raise ValueError(f"The path '{p}' is neither a file nor a directory of image.")
if isinstance(path, str):
return _get_single_image_path(path)
elif isinstance(path, list):
image_paths = []
for p in path:
image_paths.extend(_get_single_image_path(p))
return image_paths
else:
raise ValueError("The input must be a string or a list of strings.")
[docs]
def image_to_base64(
    image: Image.Image | str,
    quality: int = 85,
    max_resolution: tuple = (448, 448),
    img_format: str = "WEBP",
) -> str:
    """
    Encode an image as a base64 string after optional downscaling and EXIF removal.

    :param image: PIL Image object or the path to the image file. A path is
        opened and closed again by this function. A caller-supplied Image is
        never modified.
    :param quality: Compression quality (0-100) passed to ``Image.save``; its
        effect depends on the format (e.g. WebP). Higher values mean better
        quality but larger size.
    :param max_resolution: Optional maximum (width, height). Larger images are
        shrunk to fit within it, preserving aspect ratio. Images already within
        the bounds are not enlarged. Pass a falsy value to skip resizing.
    :param img_format: Pillow format name used for encoding. Default is "WEBP".
    :return: Base64 encoded string of the encoded image bytes.
    """
    if isinstance(image, str):
        # Open from disk and guarantee the file handle is released afterwards.
        with Image.open(image) as opened:
            return image_to_base64(opened, quality, max_resolution, img_format)

    working = image

    # Flatten alpha/palette/CMYK/etc. to RGB. Formats such as JPEG cannot
    # store those modes, and alpha was already being dropped for RGBA.
    if working.mode not in ("RGB", "L"):
        working = working.convert("RGB")

    if max_resolution:
        width, height = working.size
        max_width, max_height = max_resolution
        # Cap the scale at 1 so this is a true maximum: never upscale.
        scale = min(max_width / width, max_height / height, 1.0)
        if scale < 1.0:
            # max(1, ...) keeps extreme aspect ratios from truncating to 0 px.
            new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
            working = working.resize(new_size, Image.Resampling.LANCZOS)

    # Remove metadata (EXIF). If nothing above produced a new image, copy
    # first so the caller's Image.info is left untouched.
    if "exif" in working.info:
        if working is image:
            working = working.copy()
        del working.info["exif"]

    buffered = io.BytesIO()
    working.save(buffered, format=img_format, quality=quality)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
[docs]
def find_files_with_extensions(
directory_path: str,
extensions: str | Iterable[str],
return_all: bool = False,
) -> str | List[str] | None:
"""
Find files with the specified extensions in the given directory.
If `return_all` is False (default), returns the first matching file based on priority.
If `return_all` is True, returns a list of all matching files, sorted by priority.
Args:
directory_path (str): Path to the directory to search.
extensions (Union[str, List[str]]): A single extension or a list of extensions.
return_all (bool): If True, return all matching files; otherwise, return the first match.
Returns:
Union[str, List[str], None]: A single file path, a list of file paths, or None if no files are found.
"""
# Ensure extensions is a list for consistent handling
if isinstance(extensions, str):
extensions = [extensions]
# List to store all matching files
all_matching_files = []
try:
# Walk through the directory
for root, _, files in os.walk(directory_path):
for file in files:
# Check if the file ends with any of the specified extensions
for ext in extensions:
if file.endswith(f".{ext}"):
full_path = os.path.join(root, file)
if not return_all:
# Return the first matching file based on priority
return full_path
else:
# Add to the list of matching files
all_matching_files.append((ext, full_path))
except FileNotFoundError:
print(f"Error: The directory '{directory_path}' does not exist.")
return None if not return_all else []
except PermissionError:
print(f"Error: Permission denied to access the directory '{directory_path}'.")
return None if not return_all else []
if return_all:
# Sort the matching files by extension priority
all_matching_files.sort(key=lambda x: extensions.index(x[0]))
# Return only the file paths (without the extension used for sorting)
return [file_path for _, file_path in all_matching_files]
else:
# Return None if no matching file is found
return None
[docs]
def parse_json_response(response: str) -> Dict[str, Any]:
    """Parse the JSON object embedded in a VLM response, handling potential errors.

    The text between the first "{" and the last "}" is decoded, so prose or
    Markdown code fences around the object are ignored. If that fails and the
    response contains doubled braces ("{{" / "}}", e.g. copied from an
    unrendered prompt template), the braces are collapsed and parsing retried.

    Args:
        response: Raw model output.

    Returns:
        The decoded object, or ``{}`` if nothing could be parsed (the failure
        is logged).
    """
    try:
        try:
            # Try the text as-is first. Collapsing "}}" unconditionally would
            # corrupt legitimate nested objects such as {"a": {"b": 1}}.
            return _decode_outermost_object(response)
        except json.JSONDecodeError:
            unescaped = response.replace("{{", "{").replace("}}", "}")
            if unescaped == response:
                raise
            return _decode_outermost_object(unescaped)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse JSON response: {str(e)}")
        logger.debug(f"Original response: {response}")
        return {}
    except Exception as e:
        logger.error(f"Unexpected error parsing response: {str(e)}")
        return {}


def _decode_outermost_object(text: str) -> Any:
    """Decode the span of ``text`` from its first "{" to its last "}" as JSON."""
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        raise json.JSONDecodeError("No JSON object found", text, 0)
    return json.loads(text[start : end + 1])
[docs]
def parse_design_goal(design_goal_input: str) -> str:
    """
    Parse the design goal input, which can be either a JSON file path or plain text.

    If ``design_goal_input`` names an existing file, that file is read as UTF-8
    JSON and its top-level ``"text"`` field is returned ("" when the field is
    absent). Otherwise the input string itself is returned unchanged.

    Args:
        design_goal_input (str): Path to a JSON file or plain text.

    Returns:
        str: The design goal text.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
        ValueError: If the file's JSON top level is not an object.
    """
    if not os.path.isfile(design_goal_input):
        return design_goal_input

    with open(design_goal_input, "r", encoding="utf-8") as f:
        try:
            data = json.load(f)
        except json.JSONDecodeError:
            logger.error("The provided file is not a valid JSON.")
            # Re-raise the original error. JSONDecodeError's constructor needs
            # (msg, doc, pos), so building a new one from a message alone
            # raised TypeError instead.
            raise

    if not isinstance(data, dict):
        raise ValueError(
            f"Expected a JSON object in '{design_goal_input}', got {type(data).__name__}."
        )
    return data.get("text", "")
[docs]
def setup_logging(
    log_level: str = "INFO",
    log_file: Optional[str] = None,
    log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
) -> None:
    """
    Configure the logging system.

    Behavior:
    - Console output goes to stdout and keeps ANSI escape codes (colored output).
    - If ``log_file`` is given, records are also written to that file. The file
      is truncated on each call and written as UTF-8. Its formatter is
      ``ANSIStrippingFormatter``, which is meant to strip ANSI escape codes.
    - The root logger uses ``log_level``. The ``httpx`` logger is pinned to
      WARNING, so its INFO messages are muted whatever ``log_level`` is.

    Args:
        log_level (str): The logging level name, case-insensitive (e.g. "INFO",
            "debug"). Defaults to "INFO".
        log_file (Optional[str]): Path to the log file. If None or empty, no
            file logging is performed.
        log_format (str): The format string for log messages. Defaults to
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s".

    Returns:
        None

    Raises:
        ValueError: If ``dictConfig`` rejects the configuration (e.g. an
            unknown ``log_level``).
    """
    level = log_level.upper()

    formatters: Dict[str, Dict[str, Any]] = {"console": {"format": log_format}}
    handlers: Dict[str, Dict[str, Any]] = {
        "console": {
            "level": level,
            "class": "logging.StreamHandler",
            "formatter": "console",
            "stream": "ext://sys.stdout",
        }
    }

    if log_file:
        # The ANSI-stripping formatter is only built when a file will use it.
        formatters["file"] = {
            "()": ANSIStrippingFormatter,  # Remove ANSI escape codes for file logs
            "format": log_format,
        }
        handlers["file"] = {
            "level": level,
            "class": "logging.FileHandler",
            "filename": log_file,
            "formatter": "file",
            "mode": "w",
            "encoding": "utf-8",
        }

    handler_names = list(handlers)
    logging.config.dictConfig(
        {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": formatters,
            "handlers": handlers,
            "root": {"level": level, "handlers": handler_names},
            "loggers": {
                # Mute httpx INFO chatter. It gets its own handlers, and
                # propagate=False avoids printing each record twice via root.
                "httpx": {
                    "level": "WARNING",
                    "handlers": handler_names,
                    "propagate": False,
                },
            },
        }
    )
    logging.info("Logging configuration is set up.")
if __name__ == "__main__":
    # Demo colors, exercised in this exact order by both passes below.
    demo_colors = ("red", "green", "blue", "yellow", "magenta", "cyan", "black", "white")

    # Pass 1: print each colored sample directly.
    for shade in demo_colors:
        cprint(f"This is a {shade} message", shade)

    # Pass 2: configure logging, then route the same colored samples through the logger.
    setup_logging()
    for shade in demo_colors:
        logger.info(colorstring(f"This is a {shade} message", shade))