Source code for feedback.visual_qa

import logging
from typing import List

from cicada.core import model
from cicada.core.basics import PromptBuilder
from cicada.core.utils import get_image_paths, image_to_base64

logger = logging.getLogger(__name__)


[docs] def load_images(path: str) -> List[str]: """ Load images from a folder and convert them to base64. Args: path (str): Path to the folder containing images. Returns: List[bytes]: List of images in base64 format. """ if path: return [image_to_base64(img) for img in get_image_paths(path)] else: return []
[docs] class VisualQA(model.MultiModalModel): def __init__( self, api_key, api_base_url, model_name, org_id, prompt_templates, **model_kwargs, ): """ Initialize the VisualQA model. Args: api_key (str): API key for the model. api_base_url (str): Base URL for the API. model_name (str): Name of the model. org_id (str): Organization ID. prompt_templates (dict): Dictionary of prompt templates. **model_kwargs: Additional keyword arguments for the model. """ super().__init__( api_key, api_base_url, model_name, org_id, **model_kwargs, ) self.visual_qa_prompts = prompt_templates
[docs] def generate_questions( self, design_goal: str, reference_images: List[str] | None, num_questions: int = 5, ) -> List[str]: """ Generate questions based on the design goal and images. Args: design_goal (str): Text description of the design goal. reference_images (List[str]): List of reference images paths. num_questions (int): Number of questions to generate. Returns: List[str]: List of generated questions. """ prompt = self.visual_qa_prompts["question_generation"][ "prompt_template" ].format(design_goal=design_goal, num_questions=num_questions) pb = PromptBuilder() pb.add_system_message( self.visual_qa_prompts["question_generation"]["system_prompt"] ) pb.add_user_message(prompt) if reference_images: pb.add_text("The following is a set of reference images:") pb.add_images(reference_images) response = self.query(prompt_builder=pb, stream=self.stream)["content"] questions = [q.strip() for q in response.split("\n") if q.strip()] return questions
[docs] def generate_answers( self, design_goal: str, questions: List[str], reference_images: List[str] | None, rendered_images: List[str], ) -> dict: """ Generate answers to the provided questions based on the design goal and images. Args: design_goal (str): Text description of the design goal. questions (List[str]): List of questions to answer. reference_images (List[str]): List of reference images in str. rendered_images (List[str]): List of rendered images in str. Returns: dict: Dictionary of questions and their corresponding answers. """ questions_text = "\n".join([f"{i+1}. {q}" for i, q in enumerate(questions)]) prompt = self.visual_qa_prompts["answer_generation"]["prompt_template"].format( design_goal=design_goal, questions=questions_text ) # construct using PromptBuilder pb = PromptBuilder() pb.add_system_message( self.visual_qa_prompts["answer_generation"]["system_prompt"] ) pb.add_user_message(prompt) # Add the formatted prompt to the PromptBuilder if reference_images: pb.add_text("The followings are reference images of design goal:") pb.add_images(reference_images) pb.add_text("The following is a set of rendered object images:") pb.add_images(rendered_images) response = self.query(prompt_builder=pb, stream=self.stream)["content"] return response
[docs] def automated_qa( self, design_goal: str, reference_images: List[str] | None, rendered_images: List[str], num_questions: int = 5, ) -> dict: """ Generate questions and then generate answers to those questions. Args: design_goal (str): Text description of the design goal. reference_images (List[str]): List of reference images in str. rendered_images (List[str]): List of rendered images in str. num_questions (int): Number of questions to generate. Returns: dict: Dictionary containing the generated questions and answers. """ questions = self.generate_questions( design_goal=design_goal, reference_images=reference_images, num_questions=num_questions, ) answers = self.generate_answers( design_goal, questions, reference_images, rendered_images ) return {"questions": questions, "answers": answers}
if __name__ == "__main__": import argparse import json from cicada.core.utils import ( colorstring, cprint, get_image_paths, image_to_base64, load_config, load_prompts, parse_design_goal, setup_logging, ) setup_logging() parser = argparse.ArgumentParser(description="Visual QA Model") parser.add_argument( "--config", default="config.yaml", help="Path to the configuration YAML file" ) parser.add_argument( "--prompts", default="prompts", help="Path to the prompts YAML file or folder" ) parser.add_argument( "--mode", required=True, choices=["question_generation", "answer_generation", "automated_qa"], help="Operating mode: question_generation, answer_generation, or automated_qa", ) parser.add_argument( "--design_goal", help="Text description of the design goal or path to a JSON file containing the design goal", ) parser.add_argument("--questions", nargs="+", help="List of questions to answer") parser.add_argument( "--reference_images", help="Path to the folder containing reference images" ) parser.add_argument( "--rendered_images", help="Path to the folder containing rendered object images" ) parser.add_argument( "--num_questions", type=int, default=5, help="Number of questions to generate" ) args = vars(parser.parse_args()) config = load_config(args["config"], "visual_qa") prompt_templates = load_prompts(args["prompts"], "visual_qa") visual_qa = VisualQA( config["api_key"], config.get("api_base_url"), config.get("model_name", "gpt-4"), config.get("org_id"), prompt_templates, **config.get("model_kwargs", {}), ) if args["mode"] == "question_generation": if not args["design_goal"]: raise ValueError("Design goal is required for question generation.") design_goal = parse_design_goal(args["design_goal"]) questions = visual_qa.generate_questions( design_goal=design_goal, reference_images=args["reference_images"], num_questions=args["num_questions"], ) cprint(json.dumps(questions, indent=4), "cyan") elif args["mode"] == "answer_generation": if not args["questions"]: raise ValueError("Questions are required for answer generation.") questions = args["questions"] design_goal = ( parse_design_goal(args["design_goal"]) if args["design_goal"] else "" ) answers = visual_qa.generate_answers( design_goal=design_goal, questions=questions, reference_images=args["reference_images"], rendered_images=args["rendered_images"], ) cprint(json.dumps(answers, indent=4), "cyan") elif args["mode"] == "automated_qa": if not args["design_goal"]: raise ValueError("Design goal is required for automated QA.") design_goal = parse_design_goal(args["design_goal"]) qa_result = visual_qa.automated_qa( design_goal=design_goal, reference_images=args["reference_images"], rendered_images=args["rendered_images"], num_questions=args["num_questions"], ) cprint(json.dumps(qa_result, indent=4), "cyan") else: raise ValueError("Invalid mode selected.")