diff --git a/.gitignore b/.gitignore index 49812a19..733b2f2c 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,6 @@ venv .env data -models \ No newline at end of file +models +simulated_uploaded +__pycache__ \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/rp_handler.py b/src/rp_handler.py index c4653782..144b1175 100644 --- a/src/rp_handler.py +++ b/src/rp_handler.py @@ -18,9 +18,6 @@ COMFY_POLLING_MAX_RETRIES = 100 # Host where ComfyUI is running COMFY_HOST = "127.0.0.1:8188" -# The path where ComfyUI stores the generated images -COMFY_OUTPUT_PATH = "/comfyui/output" - def check_server(url, retries=50, delay=500): """ @@ -38,6 +35,7 @@ def check_server(url, retries=50, delay=500): for i in range(retries): try: response = requests.get(url) + # If the response status code is 200, the server is up and running if response.status_code == 200: print(f"runpod-worker-comfy - API is reachable") @@ -87,10 +85,82 @@ def get_history(prompt_id): def base64_encode(img_path): """ Returns base64 encoded image. + + Args: + img_path (str): The path to the image + + Returns: + str: The base64 encoded image """ with open(img_path, "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()) - return encoded_string.decode("utf-8") + encoded_string = base64.b64encode(image_file.read()).decode("utf-8") + return f"data:image/png;base64,{encoded_string}" + +def process_output_images(outputs, job_id): + """ + This function takes the "outputs" from image generation and the job ID, + then determines the correct way to return the image, either as a direct URL + to an AWS S3 bucket or as a base64 encoded string, depending on the + environment configuration. + + Args: + outputs (dict): A dictionary containing the outputs from image generation, + typically includes node IDs and their respective output data. + job_id (str): The unique identifier for the job. + + Returns: + dict: A dictionary with the status ('success' or 'error') and the message, + which is either the URL to the image in the AWS S3 bucket or a base64 + encoded string of the image. In case of error, the message details the issue. + + The function works as follows: + - It first determines the output path for the images from an environment variable, + defaulting to "/comfyui/output" if not set. + - It then iterates through the outputs to find the filenames of the generated images. + - After confirming the existence of the image in the output folder, it checks if the + AWS S3 bucket is configured via the BUCKET_ENDPOINT_URL environment variable. + - If AWS S3 is configured, it uploads the image to the bucket and returns the URL. + - If AWS S3 is not configured, it encodes the image in base64 and returns the string. + - If the image file does not exist in the output folder, it returns an error status + with a message indicating the missing image file. + """ + + # The path where ComfyUI stores the generated images + COMFY_OUTPUT_PATH = os.environ.get('COMFY_OUTPUT_PATH', "/comfyui/output") + + output_images = {} + + for node_id, node_output in outputs.items(): + if "images" in node_output: + for image in node_output["images"]: + output_images = image["filename"] + + print(f"runpod-worker-comfy - image generation is done") + + # expected image output folder + local_image_path = f"{COMFY_OUTPUT_PATH}/{output_images}" + + # The image is in the output folder + if os.path.exists(local_image_path): + print("runpod-worker-comfy - the image exists in the output folder") + + if os.environ.get('BUCKET_ENDPOINT_URL', False): + # URL to image in AWS S3 + image = rp_upload.upload_image(job_id, local_image_path) + else: + # base64 image + image = base64_encode(local_image_path) + + return { + "status": "success", + "message": image, + } + else: + print("runpod-worker-comfy - the image does not exist in the output folder") + return { + "status": "error", + "message": f"the image does not exist in the specified output folder: {local_image_path}", + } def handler(job): @@ -158,37 +228,10 @@ def handler(job): except Exception as e: return {"error": f"Error waiting for image generation: {str(e)}"} - # Fetching generated images - output_images = {} - - outputs = history[prompt_id].get("outputs") - - for node_id, node_output in outputs.items(): - if "images" in node_output: - for image in node_output["images"]: - output_images = image["filename"] - - print(f"runpod-worker-comfy - image generation is done") - - # expected image output folder - local_image_path = f"{COMFY_OUTPUT_PATH}/{output_images}" - # The image is in the output folder - if os.path.exists(local_image_path): - print("runpod-worker-comfy - the image exists in the output folder") - image_url = rp_upload.upload_image(job["id"], local_image_path) - return_base64 = "simulated_uploaded/" in image_url - return_output = f"{image_url}" if not return_base64 else base64_encode(local_image_path) - return { - "status": "success", - "message": return_output, - } - else: - print("runpod-worker-comfy - the image does not exist in the output folder") - return { - "status": "error", - "message": f"the image does not exist in the specified output folder: {local_image_path}", - } + # Get the generated image and return it as URL in an AWS bucket or as base64 + process_output_images(history[prompt_id].get("outputs"), job[id]) -# Start the handler -runpod.serverless.start({"handler": handler}) +# Start the handler only if this script is run directly +if __name__ == "__main__": + runpod.serverless.start({"handler": handler}) diff --git a/test_resources/images/ComfyUI_00001_.png b/test_resources/images/ComfyUI_00001_.png new file mode 100644 index 00000000..d6d3ce6c Binary files /dev/null and b/test_resources/images/ComfyUI_00001_.png differ diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_rp_handler.py b/tests/test_rp_handler.py new file mode 100644 index 00000000..23e1197f --- /dev/null +++ b/tests/test_rp_handler.py @@ -0,0 +1,130 @@ +import unittest +from unittest.mock import patch, MagicMock, mock_open, Mock +import sys +import os +import json + +# Make sure that "src" is known and can be used to import rp_handler.py +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src'))) +from src import rp_handler + +# Local folder for test resources +RUNPOD_WORKER_COMFY_TEST_RESOURCES_IMAGES = "./test_resources/images" + +class TestRunpodWorkerComfy(unittest.TestCase): + @patch('rp_handler.requests.get') + def test_check_server_server_up(self, mock_requests): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_requests.return_value = mock_response + + result = rp_handler.check_server('http://127.0.0.1:8188', 1, 50) + self.assertTrue(result) + + @patch('rp_handler.requests.get') + def test_check_server_server_down(self, mock_requests): + mock_requests.get.side_effect = rp_handler.requests.RequestException() + result = rp_handler.check_server('http://127.0.0.1:8188', 1, 50) + self.assertFalse(result) + + @patch('rp_handler.urllib.request.urlopen') + def test_queue_prompt(self, mock_urlopen): + mock_response = MagicMock() + mock_response.read.return_value = json.dumps({"prompt_id": "123"}).encode() + mock_urlopen.return_value = mock_response + result = rp_handler.queue_prompt({"prompt": "test"}) + self.assertEqual(result, {"prompt_id": "123"}) + + @patch('rp_handler.urllib.request.urlopen') + def test_get_history(self, mock_urlopen): + # Mock response data as a JSON string + mock_response_data = json.dumps({"key": "value"}).encode('utf-8') + + # Define a mock response function for `read` + def mock_read(): + return mock_response_data + + # Create a mock response object + mock_response = Mock() + mock_response.read = mock_read + + # Mock the __enter__ and __exit__ methods to support the context manager + mock_response.__enter__ = lambda s: s + mock_response.__exit__ = Mock() + + # Set the return value of the urlopen mock + mock_urlopen.return_value = mock_response + + # Call the function under test + result = rp_handler.get_history("123") + + # Assertions + self.assertEqual(result, {"key": "value"}) + mock_urlopen.assert_called_with("http://127.0.0.1:8188/history/123") + + @patch('builtins.open', new_callable=mock_open, read_data=b'test') + def test_base64_encode(self, mock_file): + result = rp_handler.base64_encode("dummy_path") + self.assertTrue(result.startswith("data:image/png;base64,")) + + @patch('rp_handler.os.path.exists') + @patch('rp_handler.rp_upload.upload_image') + @patch.dict(os.environ, {'COMFY_OUTPUT_PATH': RUNPOD_WORKER_COMFY_TEST_RESOURCES_IMAGES}) + def test_bucket_endpoint_not_configured(self, mock_upload_image, mock_exists): + mock_exists.return_value = True + mock_upload_image.return_value = 'simulated_uploaded/image.png' + + outputs = {'node_id': {'images': [{'filename': 'ComfyUI_00001_.png'}]}} + job_id = '123' + + result = rp_handler.process_output_images(outputs, job_id) + + self.assertEqual(result['status'], 'success') + self.assertTrue(result['message'].startswith("data:image/png;base64,")) + + @patch('rp_handler.os.path.exists') + @patch('rp_handler.rp_upload.upload_image') + @patch.dict(os.environ, {'COMFY_OUTPUT_PATH': RUNPOD_WORKER_COMFY_TEST_RESOURCES_IMAGES, 'BUCKET_ENDPOINT_URL': 'http://example.com'}) + def test_bucket_endpoint_configured(self, mock_upload_image, mock_exists): + # Mock the os.path.exists to return True, simulating that the image exists + mock_exists.return_value = True + + # Mock the rp_upload.upload_image to return a simulated URL + mock_upload_image.return_value = 'http://example.com/uploaded/image.png' + + # Define the outputs and job_id for the test + outputs = {'node_id': {'images': [{'filename': 'ComfyUI_00001_.png'}]}} + job_id = '123' + + # Call the function under test + result = rp_handler.process_output_images(outputs, job_id) + + # Assertions + self.assertEqual(result['status'], 'success') + self.assertEqual(result['message'], 'http://example.com/uploaded/image.png') + mock_upload_image.assert_called_once_with(job_id, './test_resources/images/ComfyUI_00001_.png') + + + @patch('rp_handler.os.path.exists') + @patch('rp_handler.rp_upload.upload_image') + @patch.dict(os.environ, { + 'COMFY_OUTPUT_PATH': RUNPOD_WORKER_COMFY_TEST_RESOURCES_IMAGES, + 'BUCKET_ENDPOINT_URL': 'http://example.com', + 'BUCKET_ACCESS_KEY_ID': '', + 'BUCKET_SECRET_ACCESS_KEY': '' + }) + def test_bucket_image_upload_fails_env_vars_wrong_or_missing(self, mock_upload_image, mock_exists): + # Simulate the file existing in the output path + mock_exists.return_value = True + + # When AWS credentials are wrong or missing, upload_image should return 'simulated_uploaded/...' + mock_upload_image.return_value = 'simulated_uploaded/image.png' + + outputs = {'node_id': {'images': [{'filename': 'ComfyUI_00001_.png'}]}} + job_id = '123' + + result = rp_handler.process_output_images(outputs, job_id) + + # Check if the image was saved to the 'simulated_uploaded' directory + self.assertIn('simulated_uploaded', result['message']) + self.assertEqual(result['status'], 'success') \ No newline at end of file