import asyncio
import os
import time
from getpass import getpass
from pathlib import Path
from typing import Dict, List, Optional, Union

import boto3
import requests
from dotenv import load_dotenv
from playwright.async_api import Page, async_playwright
from sqlalchemy import create_engine, text
from sqlalchemy.engine.base import Engine
from sqlalchemy.exc import OperationalError

load_dotenv(override=True)


def download_image(
    image_url: str, image_path: Union[str, Path], timeout: int = 5
) -> Union[str, Path]:
    """
    Downloads an image from a provided URL and saves it to a local path.

    Args:
        image_url (str): URL of the image to download.
        image_path (str | Path): Local path where the image will be saved, including the
            image file name. Its parent directory must already exist.
        timeout (int): Maximum time, in seconds, to wait for the server's response.
            Default is 5 seconds.

    Raises:
        HTTPError: If there was an unsuccessful HTTP response.
        Timeout: If the request times out.

    Returns:
        str | Path: The ``image_path`` argument, unchanged, once the file has been written.
    """
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()  # Raise exception if invalid response.

    with open(image_path, "wb") as f:
        f.write(response.content)

    return image_path


def upload_to_s3(
    image_path: Union[str, Path],
    bucket: str,
    s3_image_name: str,
    aws_access_key_id: str,
    aws_secret_access_key: str,
    region_name: str,
) -> str:
    """
    Uploads an image file to an S3 bucket, deletes the local copy, and returns its URL.

    The object is stored under the ``blog_post_covers/`` key prefix.

    Args:
        image_path (str | Path): Path to the image file to upload. The file is deleted
            from the local filesystem after a successful upload.
        bucket (str): Name of the S3 bucket to upload to.
        s3_image_name (str): Name to give to the file once it's uploaded.
        aws_access_key_id (str): AWS access key ID.
        aws_secret_access_key (str): AWS secret access key.
        region_name (str): The name of the AWS region where the S3 bucket is located.

    Returns:
        str: Virtual-hosted-style URL of the uploaded image in the S3 bucket.

    Raises:
        ClientError: If there was an error uploading the file to S3.
    """
    s3 = boto3.client(
        "s3",
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        # Use the same region the returned URL points at.
        region_name=region_name,
    )
    s3_path = "blog_post_covers/" + s3_image_name  # prepend the S3 'folder' name

    with open(image_path, "rb") as f:
        s3.upload_fileobj(f, bucket, s3_path)

    # Remove the local copy only after the handle is closed: deleting a file that is
    # still open fails on Windows. Not reached if the upload raised.
    os.remove(image_path)

    return f"https://{bucket}.s3.{region_name}.amazonaws.com/{s3_path}"


async def login_to_discord(
    page: Page,
    server_id: str,
    channel_id: str,
    email: Optional[str] = None,
    password: Optional[str] = None,
    auth_code: Optional[str] = None,
) -> None:
    """
    Log in to Discord via a Playwright browser page.

    Any credential that is not supplied is requested interactively on the terminal.
    Those prompts use the blocking ``input``/``getpass`` calls, so the event loop is
    paused while waiting for the user.

    Args:
        page (Page): Playwright browser page instance.
        server_id (str): Discord server ID to navigate to after login.
        channel_id (str): Discord channel ID to navigate to after login.
        email (Optional[str], optional): Email to use for logging in to Discord. Defaults to None.
        password (Optional[str], optional): Password to use for logging in to Discord.
            Defaults to None.
        auth_code (Optional[str], optional): Two-factor authentication code (6-digit code or
            8-digit backup code) to use for logging in to Discord. Defaults to None.

    Raises:
        TimeoutError: If any of the page actions do not complete within the default timeout period.
    """
    discord_channel_url = f"https://discord.com/channels/{server_id}/{channel_id}"
    await page.goto(discord_channel_url)
    await page.get_by_role("button", name="Continue in browser").click()

    await page.get_by_label("Email or Phone Number*").click()
    if not email:
        email = input("Please enter your email: ")
    await page.get_by_label("Email or Phone Number*").fill(email)
    await page.get_by_label("Email or Phone Number*").press("Tab")

    if not password:
        password = getpass("Please enter your password: ")
    await page.get_by_label("Password*").fill(password)
    await page.get_by_role("button", name="Log In").click()

    # Second step: the 2FA form also has a "Log In" submit button.
    if not auth_code:
        auth_code = input("Please enter your authentication code: ")
    await page.get_by_placeholder("6-digit authentication code/8-digit backup code").fill(auth_code)
    await page.get_by_role("button", name="Log In").click()


async def post_prompt(page: Page, prompt: str, channel_name: str = "general") -> None:
    """
    Post an ``/imagine`` prompt message in Discord via a Playwright browser page.

    Args:
        page (Page): Playwright browser page instance.
        prompt (str): The prompt to be posted in the message box.
        channel_name (str): Name of the channel whose message box is used; the textbox is
            located by its accessible name ``"Message #<channel_name>"``. Default "general".

    Raises:
        TimeoutError: If any of the page actions do not complete within the default timeout period.
    """
    message_text_box = page.get_by_role("textbox", name=f"Message #{channel_name}").nth(0)
    await message_text_box.fill("/imagine ")

    # Presumably typing the slash command reveals an option "pill" that receives the
    # prompt text. NOTE(review): this class name looks build-generated and may change.
    prompt_input = page.locator(".optionPillValue-2uxsMp").nth(0)
    await prompt_input.fill(prompt, timeout=2000)
    await message_text_box.press("Enter", timeout=2000)


async def upscale_image(
    page: Page, check_interval: float = 5, max_wait: Optional[float] = None
) -> None:
    """
    Upscale an image on a Discord channel using the U1 button of the last message.

    Polls until the button is visible, then clicks it.

    Args:
        page (Page): Playwright browser page instance.
        check_interval (float): Seconds to sleep between visibility checks. Default is 5.
        max_wait (Optional[float]): Maximum seconds to wait for the button. ``None``
            (the default) waits indefinitely.

    Raises:
        TimeoutError: If ``max_wait`` elapses before the button appears, or if the click
            does not complete within its 1000 ms timeout.
    """
    last_message = page.locator(selector="li").last
    upscale_1 = last_message.locator("button", has_text="U1")
    start_time = time.monotonic()

    # Wait for the upscale button to be visible
    while not await upscale_1.is_visible():
        if max_wait is not None and time.monotonic() - start_time > max_wait:
            raise TimeoutError(
                f"Waited for {max_wait} seconds but the 'U1' upscale button did not appear."
            )
        print("Upscale button is not yet available, waiting...")
        await asyncio.sleep(check_interval)

    print("Upscale button is now available, clicking...")
    await upscale_1.click(timeout=1000)


async def get_image_url(
    page: Page, timeout: int = 1000, check_interval: int = 5, max_wait: int = 30
) -> str:
    """
    Get the href of the last image link in the last message.

    Retries until the link exists with a non-empty href and the 'Vary (Strong)' button
    is visible.

    Args:
        page (Page): Playwright browser page instance.
        timeout (int): Maximum time, in milliseconds, to wait for the image link's href.
            Default is 1000 milliseconds.
        check_interval (int): Time, in seconds, to wait between checks for the button and
            image link. Default is 5 seconds.
        max_wait (int): Maximum time, in seconds, to wait before giving up. Default is 30 seconds.

    Returns:
        str: The href attribute of the last image link.

    Raises:
        TimeoutError: If the image link does not appear within the maximum wait time.
    """
    last_message = page.locator(selector="li").last
    vary_strong = last_message.locator("button", has_text="Vary (Strong)")
    image_links = last_message.locator("xpath=//a[starts-with(@class, 'originalLink-')]")
    # Monotonic clock: immune to system clock adjustments while polling.
    start_time = time.monotonic()

    # Wait for the 'Vary (Strong)' button and an image link to appear
    while True:
        if await vary_strong.is_visible() and await image_links.count() > 0:
            last_image_link = await image_links.last.get_attribute("href", timeout=timeout)
            # get_attribute returns None when the attribute is absent; keep waiting then.
            if last_image_link:
                print("Image link is present, returning it.")
                return last_image_link

        print("Waiting for 'Vary (Strong)' button to appear and for image link to appear...")

        # If the maximum wait time has been reached, raise an exception
        if time.monotonic() - start_time > max_wait:
            raise TimeoutError(
                f"Waited for {max_wait} seconds but 'Vary (Strong)' button did not appear "
                "and image link did not appear."
            )

        await asyncio.sleep(check_interval)


def update_db_record(
    engine: Engine, s3_path: str, keyword_value: str, max_retries: int = 5, retry_wait: int = 2
) -> None:
    """
    Update a database record's blog_post_cover_image_url field with an S3 URL.

    The UPDATE runs in a transaction that is committed on success. It is retried on
    ``OperationalError``.

    Args:
        engine (Engine): SQLAlchemy Engine instance.
        s3_path (str): S3 URL to be added to the blog_post_cover_image_url field.
        keyword_value (str): Slug identifying the specific ``keywords`` row to update.
        max_retries (int): Maximum number of attempts in case of failure. Default is 5.
        retry_wait (int): Time, in seconds, to wait between attempts. Default is 2 seconds.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        OperationalError: The last error, if every attempt failed with one.
        SQLAlchemyError: If any other SQLAlchemy error occurs while updating the record.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    query = text(
        "UPDATE keywords SET blog_post_cover_image_url = :s3_path WHERE slug = :keyword_value"
    )
    params = {"s3_path": s3_path, "keyword_value": keyword_value}

    for attempt in range(1, max_retries + 1):
        try:
            # engine.begin() commits on success and rolls back on error, in both
            # SQLAlchemy 1.x and 2.0 (a plain connect() does not commit under 2.0).
            with engine.begin() as connection:
                connection.execute(query, params)
            return
        except OperationalError:
            print(f"OperationalError occurred. Retry {attempt} of {max_retries}.")
            if attempt == max_retries:
                raise  # Out of retries: propagate the last OperationalError.
            time.sleep(retry_wait)


def get_records_with_null_cover_image(engine: Engine) -> List[Dict[str, str]]:
    """
    Retrieve records from the database where blog_post_cover_image_url is NULL.

    Args:
        engine (Engine): SQLAlchemy Engine instance.

    Returns:
        List[Dict[str, str]]: A list of dictionaries where each dictionary represents a
        record with 'slug' and 'blog_post_cover_prompt' as keys.

    Raises:
        SQLAlchemyError: If any SQLAlchemy error occurs while retrieving the records.
    """
    query = text(
        "SELECT slug, blog_post_cover_prompt FROM keywords WHERE blog_post_cover_image_url IS NULL"
    )
    with engine.connect() as connection:
        result = connection.execute(query)
        return [{"slug": slug, "blog_post_cover_prompt": prompt} for slug, prompt in result]


S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME")
S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY")
S3_REGION_NAME = os.environ.get("S3_REGION_NAME")
DATABASE_URL = os.environ.get("DATABASE_URL")

DISCORD_SERVER_ID = "1124815914815201481"
DISCORD_CHANNEL_ID = "1124815915297542217"
# Original (misspelled) name, kept as an alias for any code that still references it.
DISCORD_CHANEL_ID = DISCORD_CHANNEL_ID

IMAGE_PATH = Path(__file__).parent / "temp_images"


def _check_required_settings() -> None:
    """Raise RuntimeError naming every required environment variable that is unset or empty."""
    required = {
        "S3_BUCKET_NAME": S3_BUCKET_NAME,
        "S3_ACCESS_KEY_ID": S3_ACCESS_KEY_ID,
        "S3_SECRET_ACCESS_KEY": S3_SECRET_ACCESS_KEY,
        "S3_REGION_NAME": S3_REGION_NAME,
        "DATABASE_URL": DATABASE_URL,
    }
    missing = [name for name, value in required.items() if not value]
    if missing:
        raise RuntimeError(f"Missing required environment variables: {', '.join(missing)}")


async def _process_record(page: Page, engine: Engine, record: Dict[str, str]) -> None:
    """Generate, download, upload and store the cover image for one keyword record."""
    slug = record["slug"]
    prompt = record["blog_post_cover_prompt"]

    await post_prompt(page=page, prompt=prompt)
    await upscale_image(page=page)
    image_url = await get_image_url(page=page)

    local_image_path = IMAGE_PATH / f"{slug}.png"
    image_path = download_image(image_url=image_url, image_path=local_image_path)

    s3_path = upload_to_s3(
        image_path=image_path,
        aws_access_key_id=S3_ACCESS_KEY_ID,
        aws_secret_access_key=S3_SECRET_ACCESS_KEY,
        bucket=S3_BUCKET_NAME,
        region_name=S3_REGION_NAME,
        s3_image_name=f"{slug}.png",
    )
    update_db_record(engine=engine, s3_path=s3_path, keyword_value=slug)


async def main(start_index: int = 181) -> None:
    """
    Create a cover image for every keyword record whose blog_post_cover_image_url is NULL.

    Logs in to Discord in a visible Chromium window. For each record it then:
    1. posts the record's prompt with ``/imagine``;
    2. upscales the result and reads the image URL;
    3. downloads the image and uploads it to S3;
    4. stores the S3 URL back in the database.

    Args:
        start_index (int): Index into the fetched records at which to start processing.
            Defaults to 181, the offset hard-coded in the original script.
            NOTE(review): why the first 181 records are skipped is not visible here — confirm.

    Raises:
        RuntimeError: If a required environment variable is missing.
    """
    _check_required_settings()
    # download_image writes into this directory but does not create it.
    IMAGE_PATH.mkdir(parents=True, exist_ok=True)

    engine = create_engine(DATABASE_URL)
    try:
        async with async_playwright() as playwright:
            browser = await playwright.chromium.launch(headless=False)
            try:
                context = await browser.new_context()
                page = await context.new_page()

                records = get_records_with_null_cover_image(engine)

                await login_to_discord(
                    page=page,
                    server_id=DISCORD_SERVER_ID,
                    channel_id=DISCORD_CHANNEL_ID,
                )

                for record in records[start_index:]:
                    await _process_record(page=page, engine=engine, record=record)

                await context.close()
            finally:
                # Close the browser even if a record fails part-way through.
                await browser.close()
    finally:
        engine.dispose()


if __name__ == "__main__":
    asyncio.run(main())