"""coding.code_cache: SQLite-backed cache of coding sessions, code iterations, and their errors."""

import json
import logging
import os
import sqlite3
from contextlib import contextmanager
from sqlite3 import Connection

logger = logging.getLogger(__name__)


[docs] class CodeCache: def __init__(self, db_file="coding.db"): # Ensure the directory exists db_dir = os.path.dirname(db_file) if db_dir and not os.path.exists(db_dir): os.makedirs(db_dir) self.db_file = db_file self.connection_pool = [] self.initialize_database() def _get_connection(self) -> Connection: if len(self.connection_pool) == 0: return sqlite3.connect(self.db_file) return self.connection_pool.pop() def _return_connection(self, conn: Connection): if conn: self.connection_pool.append(conn)
[docs] def initialize_database(self): conn = self._get_connection() cursor = conn.cursor() cursor.execute( """ CREATE TABLE IF NOT EXISTS session ( id INTEGER PRIMARY KEY AUTOINCREMENT, design_goal TEXT NOT NULL, parent_session_id INTEGER, coding_plan TEXT, created_at TIMESTAMP DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), FOREIGN KEY (parent_session_id) REFERENCES session(id) ) """ ) cursor.execute( """ CREATE TABLE IF NOT EXISTS iteration ( id INTEGER PRIMARY KEY AUTOINCREMENT, session_id INTEGER NOT NULL, code TEXT NOT NULL, feedback TEXT, is_correct INTEGER DEFAULT 0, -- 1 for True, 0 for False is_runnable INTEGER DEFAULT 0, -- 1 for True, 0 for False created_at TIMESTAMP DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), FOREIGN KEY (session_id) REFERENCES session(id) ) """ ) cursor.execute( """ CREATE TABLE IF NOT EXISTS error ( id INTEGER PRIMARY KEY AUTOINCREMENT, iteration_id INTEGER NOT NULL, error_type TEXT CHECK(error_type IN ('syntax', 'runtime')), error_message TEXT, error_line INTEGER, created_at TIMESTAMP DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), FOREIGN KEY (iteration_id) REFERENCES iteration(id) ) """ ) conn.commit() self._return_connection(conn) logger.info("Database tables initialized.")
# Session API
[docs] def get_session(self, session_id, fields=None): conn = self._get_connection() cursor = conn.cursor() if fields is None: fields = "*" else: fields = ", ".join(fields) cursor.execute( f""" SELECT {fields} FROM session WHERE id = ? """, (session_id,), ) result = cursor.fetchone() self._return_connection(conn) if result: return dict( zip([description[0] for description in cursor.description], result) ) else: logger.warning(f"No session found with ID: {session_id}") return None
[docs] def insert_session( self, design_goal, parent_session_id=None, coding_plan=None ): # Added coding_plan parameter conn = self._get_connection() cursor = conn.cursor() cursor.execute( """ INSERT INTO session (design_goal, parent_session_id, coding_plan) -- Updated INSERT statement VALUES (?, ?, ?) """, (design_goal, parent_session_id, coding_plan), ) conn.commit() session_id = cursor.lastrowid self._return_connection(conn) logger.info(f"Session inserted with ID: {session_id}") return session_id
[docs] def update_session( self, session_id, design_goal: str | None = None, coding_plan: dict | None = None, ): # Make design_goal optional and denote coding_plan as dict|None if design_goal is None and coding_plan is None: logger.warning("Either design_goal or coding_plan must be provided.") return conn = self._get_connection() cursor = conn.cursor() if design_goal is not None and coding_plan is not None: cursor.execute( """ UPDATE session SET design_goal = ?, coding_plan = ? WHERE id = ? """, (design_goal, json.dumps(coding_plan), session_id), ) elif design_goal is not None: cursor.execute( """ UPDATE session SET design_goal = ? WHERE id = ? """, (design_goal, session_id), ) elif coding_plan is not None: cursor.execute( """ UPDATE session SET coding_plan = ? WHERE id = ? """, (json.dumps(coding_plan), session_id), ) conn.commit() self._return_connection(conn) logger.info(f"Session with ID {session_id} updated.")
[docs] def get_iteration(self, iteration_id, fields=None): conn = self._get_connection() cursor = conn.cursor() if fields is None: fields = "*" else: fields = ", ".join(fields) cursor.execute( f""" SELECT {fields} FROM iteration WHERE id = ? """, (iteration_id,), ) result = cursor.fetchone() self._return_connection(conn) if result: return dict( zip([description[0] for description in cursor.description], result) ) else: logger.warning(f"No iteration found with ID: {iteration_id}") return None
[docs] def get_iterations(self, session_id, fields=None): conn = self._get_connection() cursor = conn.cursor() if fields is None: fields = "*" else: fields = ", ".join(fields) cursor.execute( f""" SELECT {fields} FROM iteration WHERE session_id = ? """, (session_id,), ) results = cursor.fetchall() self._return_connection(conn) if results: return [ dict( zip([description[0] for description in cursor.description], result) ) for result in results ] else: logger.warning(f"No iterations found for session ID: {session_id}") return None
[docs] def insert_iteration(self, session_id, code, feedback=None): conn = self._get_connection() cursor = conn.cursor() cursor.execute( """ INSERT INTO iteration (session_id, code, feedback, is_correct, is_runnable) VALUES (?, ?, ?, 0, 1) -- 0 for False, 1 for True """, (session_id, code, feedback), ) conn.commit() iteration_id = cursor.lastrowid self._return_connection(conn) logger.info(f"Iteration inserted with ID: {iteration_id}") return iteration_id
[docs] def update_iteration( self, iteration_id, code=None, feedback=None, is_correct=None, is_runnable=None ): conn = self._get_connection() cursor = conn.cursor() set_clauses = [] params = [] if code is not None: set_clauses.append("code = ?") params.append(code) if feedback is not None: set_clauses.append("feedback = ?") params.append(feedback) if is_correct is not None: set_clauses.append("is_correct = ?") params.append(int(is_correct)) if is_runnable is not None: set_clauses.append("is_runnable = ?") params.append(int(is_runnable)) if set_clauses: query = f"UPDATE iteration SET {', '.join(set_clauses)} WHERE id = ?" params.append(iteration_id) cursor.execute(query, tuple(params)) conn.commit() self._return_connection(conn) logger.info(f"Iteration with ID {iteration_id} updated.")
# Error API
[docs] def get_errors(self, iteration_id, fields=None): conn = self._get_connection() cursor = conn.cursor() if fields is None: fields = "*" else: fields = ", ".join(fields) cursor.execute( f""" SELECT {fields} FROM error WHERE iteration_id = ? """, (iteration_id,), ) results = cursor.fetchall() self._return_connection(conn) if results: return [ dict( zip([description[0] for description in cursor.description], result) ) for result in results ] else: logger.warning(f"No errors found for iteration ID: {iteration_id}") return None
[docs] def insert_error(self, iteration_id, error_type, error_message, error_line=None): conn = self._get_connection() cursor = conn.cursor() # Set is_runnable to False if any error is present cursor.execute( """ UPDATE iteration SET is_runnable = 0 WHERE id = ? """, (iteration_id,), ) cursor.execute( """ INSERT INTO error (iteration_id, error_type, error_message, error_line) VALUES (?, ?, ?, ?) """, (iteration_id, error_type, error_message, error_line), ) conn.commit() error_id = cursor.lastrowid self._return_connection(conn) logger.info(f"Error inserted with ID: {error_id}") return error_id
[docs] def update_error(self, error_id, error_message, error_line=None): conn = self._get_connection() cursor = conn.cursor() cursor.execute( """ UPDATE error SET error_message = ?, error_line = ? WHERE id = ? """, (error_message, error_line, error_id), ) conn.commit() self._return_connection(conn) logger.info(f"Error with ID {error_id} updated.")
# Additional Method
[docs] def get_session_history(self, session_id): conn = self._get_connection() cursor = conn.cursor() cursor.execute( """ SELECT * FROM session WHERE id = ? """, (session_id,), ) session = cursor.fetchone() if session is None: self._return_connection(conn) return None cursor.execute( """ SELECT * FROM iteration WHERE session_id = ? """, (session_id,), ) iterations = cursor.fetchall() for iteration in iterations: iteration_id = iteration[0] cursor.execute( """ SELECT * FROM error WHERE iteration_id = ? """, (iteration_id,), ) errors = cursor.fetchall() iteration += (errors,) self._return_connection(conn) session += (iterations,) return session
[docs] def get_session_id(self, identifier, type_="iteration"): conn = self._get_connection() cursor = conn.cursor() if type_ == "iteration": cursor.execute( """ SELECT session_id FROM iteration WHERE id = ? """, (identifier,), ) elif type_ == "error": cursor.execute( """ SELECT session_id FROM iteration WHERE id = ( SELECT iteration_id FROM error WHERE id = ? ) """, (identifier,), ) else: self._return_connection(conn) logger.error(f"Invalid type: {type_}") return None result = cursor.fetchone() self._return_connection(conn) if result: return result[0] else: logger.warning(f"No session found for {type_} ID: {identifier}") return None
[docs] def close(self): for conn in self.connection_pool: if conn: conn.close() self.connection_pool.clear() logger.info("All database connections closed and cleaned up.")
# Usage example
if __name__ == "__main__":
    from cicada.common.utils import setup_logging

    setup_logging()

    # Create an instance of CodeCache
    code_cache = CodeCache("/tmp/cicada/coding.db")
    try:
        # Insert a new session
        session_id_1 = code_cache.insert_session("First Test Session")
        logger.info(f"Created session with ID: {session_id_1}")

        # Insert code snippets for the first session
        code_id_1 = code_cache.insert_iteration(
            session_id_1, 'print("Hello, World!")', "Looks good"
        )
        code_id_2 = code_cache.insert_iteration(
            session_id_1, 'print("Goodbye!")', "Syntax error"
        )

        # Insert errors for the second iteration (this also clears is_runnable)
        code_cache.insert_error(code_id_2, "syntax", "IndentationError", 2)

        # After fixing errors, update iteration flags
        # Assuming errors are fixed
        code_cache.update_iteration(code_id_2, is_runnable=True)

        # Validate and set is_correct to True
        # Assuming external validation confirms correctness
        code_cache.update_iteration(code_id_1, is_correct=True)

        # Retrieve and print iterations with their flags
        iterations = code_cache.get_iterations(
            session_id_1, fields=["id", "code", "is_correct", "is_runnable"]
        )
        # get_iterations returns None (not []) when nothing matches.
        for iteration in iterations or []:
            logger.info(
                f"Iteration ID: {iteration['id']}, Code: {iteration['code']}, Correct: {bool(iteration['is_correct'])}, Runnable: {bool(iteration['is_runnable'])}"
            )
    finally:
        # Clean up resources even if one of the steps above raised
        code_cache.close()