Skip to content

Instantly share code, notes, and snippets.

@feliche93
Created July 3, 2023 08:44
Show Gist options
  • Select an option

  • Save feliche93/0c928a9ca2ee8bc9b907173a007b3868 to your computer and use it in GitHub Desktop.

Select an option

Save feliche93/0c928a9ca2ee8bc9b907173a007b3868 to your computer and use it in GitHub Desktop.

Revisions

  1. feliche93 created this gist Jul 3, 2023.
    336 changes: 336 additions & 0 deletions discord_midjourney_automation.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,336 @@
    import asyncio
    import os
    from getpass import getpass
    from pathlib import Path
    from typing import Dict, List, Optional

    import boto3
    import requests
    from dotenv import load_dotenv
    from playwright.async_api import Page, async_playwright
    from sqlalchemy import create_engine, text
    from sqlalchemy.engine.base import Engine
    import time
    from sqlalchemy.exc import OperationalError
    from sqlalchemy.exc import OperationalError


    load_dotenv(override=True)


    def download_image(image_url: str, image_path: str, timeout: int = 5) -> str:
    """
    Downloads an image from a provided URL and saves it to a local path.
    Args:
    image_url (str): URL of the image to download.
    image_path (str): Local path where the image will be saved, including the image file name.
    timeout (int): Maximum time, in seconds, to wait for the server's response. Default is 5 seconds.
    Raises:
    HTTPError: If there was an unsuccessful HTTP response.
    Timeout: If the request times out.
    Returns:
    str: Local path where the image has been saved.
    """
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status() # Raise exception if invalid response.
    with open(image_path, "wb") as f:
    f.write(response.content)
    return image_path


    def upload_to_s3(
    image_path: str,
    bucket: str,
    s3_image_name: str,
    aws_access_key_id: str,
    aws_secret_access_key: str,
    region_name: str,
    ) -> str:
    """
    Uploads an image file to an S3 bucket and returns the URL of the uploaded file.
    Args:
    image_path (str): Path to the image file to upload.
    bucket (str): Name of the S3 bucket to upload to.
    s3_image_name (str): Name to give to the file once it's uploaded.
    aws_access_key_id (str): AWS access key ID.
    aws_secret_access_key (str): AWS secret access key.
    region_name (str): The name of the AWS region where the S3 bucket is located.
    Returns:
    str: URL of the uploaded image in the S3 bucket.
    Raises:
    ClientError: If there was an error uploading the file to S3.
    """
    s3 = boto3.client(
    "s3", aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key
    )
    with open(image_path, "rb") as f:
    s3_path = "blog_post_covers/" + s3_image_name # prepend the S3 'folder' name
    s3.upload_fileobj(f, bucket, s3_path)

    # remove the image from the local filesystem
    os.remove(image_path)

    url = f"https://{bucket}.s3.{region_name}.amazonaws.com/{s3_path}"
    return url


    async def login_to_discord(
    page: Page,
    server_id: str,
    channel_id: str,
    email: Optional[str] = None,
    password: Optional[str] = None,
    auth_code: Optional[str] = None,
    ) -> None:
    """
    Log in to Discord via a Playwright browser page.
    Args:
    page (Page): Playwright browser page instance.
    server_id (str): Discord server ID to navigate to after login.
    channel_id (str): Discord channel ID to navigate to after login.
    email (Optional[str], optional): Email to use for logging in to Discord. Defaults to None.
    password (Optional[str], optional): Password to use for logging in to Discord. Defaults to None.
    auth_code (Optional[str], optional): Authentication code to use for logging in to Discord. Defaults to None.
    Raises:
    TimeoutError: If any of the page actions do not complete within the default timeout period.
    """
    discord_channel_url = f"https://discord.com/channels/{server_id}/{channel_id}"
    await page.goto(discord_channel_url)

    await page.get_by_role("button", name="Continue in browser").click()
    await page.get_by_label("Email or Phone Number*").click()

    if not email:
    email = input("Please enter your email: ")
    await page.get_by_label("Email or Phone Number*").fill(email)

    await page.get_by_label("Email or Phone Number*").press("Tab")

    if not password:
    password = getpass("Please enter your password: ")
    await page.get_by_label("Password*").fill(password)

    await page.get_by_role("button", name="Log In").click()

    if not auth_code:
    auth_code = input("Please enter your authentication code: ")
    await page.get_by_placeholder("6-digit authentication code/8-digit backup code").fill(auth_code)

    await page.get_by_role("button", name="Log In").click()


    async def post_prompt(page: Page, prompt: str) -> None:
    """
    Post a prompt message in Discord via a Playwright browser page.
    Args:
    page (Page): Playwright browser page instance.
    prompt (str): The prompt to be posted in the message box.
    Raises:
    TimeoutError: If any of the page actions do not complete within the default timeout period.
    """
    message_text_boy = page.get_by_role("textbox", name="Message #general").nth(0)
    await message_text_boy.fill("/imagine ")

    prompt_input = page.locator(".optionPillValue-2uxsMp").nth(0)

    await prompt_input.fill(prompt, timeout=2000)
    await message_text_boy.press("Enter", timeout=2000)


    async def upscale_image(page: Page) -> None:
    """
    Upscale an image on a Discord channel using the U1 button.
    Args:
    page (Page): Playwright browser page instance.
    Raises:
    TimeoutError: If any of the page actions do not complete within the default timeout period.
    """
    last_message = page.locator(selector="li").last
    upscale_1 = last_message.locator("button", has_text="U1")

    # Wait for the upscale button to be visible
    while not await upscale_1.is_visible():
    print("Upscale button is not yet available, waiting...")
    await asyncio.sleep(5) # wait for 5 seconds

    print("Upscale button is now available, clicking...")
    await upscale_1.click(timeout=1000)


    async def get_image_url(
    page: Page, timeout: int = 1000, check_interval: int = 5, max_wait: int = 30
    ) -> str:
    """
    Get the href attribute of the last image link on the page, retrying until it exists and the 'Vary (Strong)' button is visible.
    Args:
    page (Page): Playwright browser page instance.
    timeout (int): Maximum time, in milliseconds, to wait for the image link. Default is 1000 milliseconds.
    check_interval (int): Time, in seconds, to wait between checks for the button and image link. Default is 5 seconds.
    max_wait (int): Maximum time, in seconds, to wait before giving up. Default is 30 seconds.
    Returns:
    str: The href attribute of the last image link.
    Raises:
    TimeoutError: If the image link does not appear within the maximum wait time.
    """

    last_message = page.locator(selector="li").last
    vary_strong = last_message.locator("button", has_text="Vary (Strong)")
    image_links = last_message.locator("xpath=//a[starts-with(@class, 'originalLink-')]")

    start_time = time.time()

    # Wait for the 'Vary (Strong)' button and an image link to appear
    while True:
    if await vary_strong.is_visible() and await image_links.count() > 0:
    last_image_link = await image_links.last.get_attribute("href", timeout=timeout)
    print("Image link is present, returning it.")
    return last_image_link

    print("Waiting for 'Vary (Strong)' button to appear and for image link to appear...")

    # If the maximum wait time has been reached, raise an exception
    if time.time() - start_time > max_wait:
    raise TimeoutError(
    "Waited for 30 seconds but 'Vary (Strong)' button did not appear and image link did not appear."
    )

    await asyncio.sleep(check_interval) # wait for 5 seconds


    def update_db_record(
    engine: Engine, s3_path: str, keyword_value: str, max_retries: int = 5, retry_wait: int = 2
    ) -> None:
    """
    Update a database record's blog_post_cover_image_url field with an S3 URL.
    Args:
    engine (Engine): SQLAlchemy Engine instance.
    s3_path (str): S3 URL to be added to the blog_post_cover_image_url field.
    keyword_value (str): Keyword value to identify the specific record to be updated.
    max_retries (int): Maximum number of retries in case of failure. Default is 5.
    retry_wait (int): Time, in seconds, to wait between retries. Default is 2 seconds.
    Raises:
    SQLAlchemyError: If any SQLAlchemy error occurs while updating the record.
    """
    retries = 0
    while retries < max_retries:
    try:
    with engine.connect() as connection:
    query = text(
    "UPDATE keywords SET blog_post_cover_image_url = :s3_path WHERE slug = :keyword_value"
    )
    connection.execute(query, s3_path=s3_path, keyword_value=keyword_value)
    break # break the loop if the operation is successful
    except OperationalError:
    retries += 1
    print(f"OperationalError occurred. Retry {retries} of {max_retries}.")
    time.sleep(retry_wait)
    else: # If we've exhausted all retries, re-raise the last exception
    raise


    def get_records_with_null_cover_image(engine: Engine) -> List[Dict[str, str]]:
    """
    Retrieve records from the database where blog_post_cover_image_url is NULL.
    Args:
    engine (Engine): SQLAlchemy Engine instance.
    Returns:
    List[Dict[str, str]]: A list of dictionaries where each dictionary represents a record
    with 'slug' and 'blog_post_cover_prompt' as keys.
    Raises:
    SQLAlchemyError: If any SQLAlchemy error occurs while retrieving the records.
    """
    with engine.connect() as connection:
    query = text(
    "SELECT slug, blog_post_cover_prompt FROM keywords WHERE blog_post_cover_image_url IS NULL"
    )
    result = connection.execute(query)
    records = [{"slug": row[0], "blog_post_cover_prompt": row[1]} for row in result]
    return records


    S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME")
    S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID")
    S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY")
    S3_REGION_NAME = os.environ.get("S3_REGION_NAME")
    DATABASE_URL = os.environ.get("DATABASE_URL")
    DISCORD_SERVER_ID = "1124815914815201481"
    DISCORD_CHANEL_ID = "1124815915297542217"

    IMAGE_PATH = Path(__file__).parent / "temp_images"


    async def main() -> None:
    async with async_playwright() as playwright:
    # playwright = await async_playwright().start()

    engine = create_engine(DATABASE_URL)

    browser = await playwright.chromium.launch(headless=False)
    context = await browser.new_context()
    page = await context.new_page()

    records = get_records_with_null_cover_image(engine)

    await login_to_discord(
    page=page,
    server_id=DISCORD_SERVER_ID,
    channel_id=DISCORD_CHANEL_ID,
    )

    for record in records[181:]:
    slug = record["slug"]
    prompt = record["blog_post_cover_prompt"]

    await post_prompt(
    page=page,
    prompt=prompt,
    )

    await upscale_image(page=page)

    image_url = await get_image_url(page=page)

    local_image_path = IMAGE_PATH / f"{slug}.png"

    image_path = download_image(image_url=image_url, image_path=local_image_path)

    s3_path = upload_to_s3(
    image_path=image_path,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
    bucket=S3_BUCKET_NAME,
    region_name=S3_REGION_NAME,
    s3_image_name=f"{slug}.png",
    )

    update_db_record(
    engine=engine,
    s3_path=s3_path,
    keyword_value=slug,
    )

    await context.close()
    await browser.close()


    asyncio.run(main())