Source code for feedback.visual_feedback

import argparse
import logging
from typing import List

from cicada.core import model
from cicada.core.basics import PromptBuilder


logger = logging.getLogger(__name__)


class VisualFeedback(model.MultiModalModel):
    """Multimodal model that critiques rendered objects against a design
    goal and optional reference images."""

    def __init__(
        self,
        api_key,
        api_base_url,
        model_name,
        org_id,
        prompt_templates,
        **model_kwargs,
    ):
        super().__init__(
            api_key,
            api_base_url,
            model_name,
            org_id,
            **model_kwargs,
        )
        self.visual_feedback_prompts = prompt_templates

    def generate_feedback_paragraph(
        self,
        design_goal: str,
        reference_images: List[str] | None,
        rendered_images: List[str],
    ) -> str:
        """
        Generate a feedback paragraph comparing the rendered object with the
        design goal and reference images, focusing on geometry, shape, and
        physical feasibility.

        :param design_goal: Text description of the design specifications.
        :param reference_images: List of byte data for reference images.
        :param rendered_images: List of byte data for rendered object images.
        :return: A paragraph of feedback highlighting hits and misses.
        """
        # Fill the user prompt template with the design goal
        prompt = self.visual_feedback_prompts["user_prompt_template"].format(
            text=design_goal
        )

        pb = PromptBuilder()
        pb.add_system_message(self.visual_feedback_prompts["system_prompt_template"])
        pb.add_user_message(prompt)
        if reference_images:
            pb.add_text("The following is a set of reference images:")
            pb.add_images(reference_images)
        pb.add_text("The following is a set of rendered object images:")
        pb.add_images(rendered_images)

        # Query the VLM with the assembled prompt and images
        response = self.query(prompt_builder=pb, stream=self.stream)["content"]

        # Return the stripped feedback paragraph
        return response.strip()
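
A minimal usage sketch (illustrative, not from the package): the credential, template text, and image paths below are placeholders; the two template keys and the {text} placeholder are exactly what __init__ and generate_feedback_paragraph read, and whether image entries are file paths or encoded bytes depends on what PromptBuilder.add_images accepts in cicada.core.basics.

prompt_templates = {
    "system_prompt_template": "You are a meticulous design reviewer.",
    "user_prompt_template": "Evaluate the renders against this design goal: {text}",
}
vf = VisualFeedback(
    api_key="sk-...",  # hypothetical credential
    api_base_url=None,
    model_name="gpt-4",  # same default the __main__ block below uses
    org_id=None,
    prompt_templates=prompt_templates,
)
feedback = vf.generate_feedback_paragraph(
    design_goal="a stackable mug with a comfortable handle",
    reference_images=None,  # optional; the reference block is skipped when falsy
    rendered_images=["renders/front.png", "renders/side.png"],  # hypothetical paths
)
print(feedback)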

def parse_args() -> argparse.Namespace:
    """
    Parse command line arguments.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Visual Feedback Model")
    parser.add_argument(
        "--config", default="config.yaml", help="Path to the configuration YAML file"
    )
    parser.add_argument(
        "--prompts", default="prompts", help="Path to the prompts YAML file or folder"
    )
    parser.add_argument(
        "--design_goal",
        required=True,
        help=(
            "Text description of the design goal or path to a JSON file "
            "containing the design goal"
        ),
    )
    parser.add_argument(
        "--reference_images", help="Path to the folder containing reference images"
    )
    parser.add_argument(
        "--rendered_images",
        required=True,
        help="Path to the folder containing rendered object images",
    )
    return parser.parse_args()

if __name__ == "__main__":
    from cicada.core.utils import (
        cprint,
        load_config,
        load_prompts,
        parse_design_goal,
        setup_logging,
    )

    args = parse_args()
    config = load_config(args.config, "visual_feedback")
    prompt_templates = load_prompts(args.prompts, "visual_feedback")

    # Initialize the VisualFeedback model
    visual_feedback = VisualFeedback(
        config["api_key"],
        config.get("api_base_url"),
        config.get("model_name", "gpt-4"),
        config.get("org_id"),
        prompt_templates,
        **config.get("model_kwargs", {}),
    )

    # Parse the design goal (plain text or a path to a JSON file)
    design_goal = parse_design_goal(args.design_goal)

    # Generate feedback for the rendered images
    feedback = visual_feedback.generate_feedback_paragraph(
        design_goal, args.reference_images, args.rendered_images
    )

    # Print the feedback
    cprint(feedback, "cyan")
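
When run as a script, an invocation might look like the following; the flag names come from parse_args above, while the paths and goal text are hypothetical:

python -m feedback.visual_feedback \
    --config config.yaml \
    --prompts prompts \
    --design_goal "a stackable mug with a comfortable handle" \
    --reference_images ./reference_images \
    --rendered_images ./renders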