-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlambda_function.py
139 lines (123 loc) · 5.69 KB
/
lambda_function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import base64
import io
import json
import logging
import urllib.parse

import boto3
from botocore.exceptions import ClientError
from PIL import Image
class ImageError(Exception):
    """Custom exception for errors returned by SDXL."""

    def __init__(self, message):
        # Forward the message to Exception so str(err) and err.args carry it,
        # while keeping the ``.message`` attribute the handler logs.
        super().__init__(message)
        self.message = message
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
# basicConfig only configures the root logger when it has no handlers yet,
# which makes it a no-op in environments that pre-install one (the AWS Lambda
# Python runtime is documented to do so). Set this logger's own level too so
# INFO messages are not filtered out in that case.
logging.basicConfig(level=logging.INFO)
logger.setLevel(logging.INFO)
def generate_image(model_id, body):
    """
    Generate an image using SDXL 1.0 on demand.

    Args:
        model_id (str): The Bedrock model ID to invoke.
        body (str): The JSON-encoded request body.

    Returns:
        image_bytes (bytes): The decoded image from the first artifact.

    Raises:
        ImageError: If the response contains no artifacts, the first
            artifact's finishReason is ERROR or CONTENT_FILTERED, or the
            artifact carries no base64 image data.
        botocore.exceptions.ClientError: If the Bedrock call itself fails.
    """
    logger.info("Generating image with SDXL model %s", model_id)
    bedrock = boto3.client(service_name='bedrock-runtime')
    accept = "application/json"
    content_type = "application/json"
    response = bedrock.invoke_model(
        body=body, modelId=model_id, accept=accept, contentType=content_type
    )
    response_body = json.loads(response.get("body").read())
    logger.info("SDXL result: %s", response_body.get('result'))

    # Guard against a missing or empty artifact list instead of crashing
    # with TypeError/IndexError.
    artifacts = response_body.get("artifacts") or []
    if not artifacts:
        raise ImageError("Image generation error. Response contained no artifacts")
    artifact = artifacts[0]

    # Check the finish reason BEFORE decoding: an errored artifact may carry
    # no image data, and decoding first would hide the real cause.
    finish_reason = artifact.get("finishReason")
    if finish_reason in ('ERROR', 'CONTENT_FILTERED'):
        raise ImageError(f"Image generation error. Error code is {finish_reason}")

    base64_image = artifact.get("base64")
    if base64_image is None:
        raise ImageError("Image generation error. Artifact contained no image data")
    image_bytes = base64.b64decode(base64_image.encode('ascii'))

    logger.info("Successfully generated image with the SDXL 1.0 model %s", model_id)
    return image_bytes
def get_image_from_s3(bucket_name, object_key):
    """
    Retrieve an image from S3 and resize it so both sides are multiples of 64.

    Each dimension is rounded UP to the next multiple of 64 pixels and the
    image is resized (not cropped or padded) to that size with LANCZOS
    resampling.

    Args:
        bucket_name (str): The S3 bucket name.
        object_key (str): The S3 object key (file path).

    Returns:
        str: The resized image encoded as a base64 string. It is re-encoded
        in the source image's format as reported by Pillow (MPO is written
        as JPEG), or as PNG when Pillow reports no format.

    Raises:
        botocore.exceptions.ClientError: If the object cannot be fetched
            from S3 (logged, then re-raised).
    """
    s3 = boto3.client('s3')
    try:
        response = s3.get_object(Bucket=bucket_name, Key=object_key)
        image_data = response['Body'].read()
    except ClientError as err:
        logger.error("Failed to retrieve image from S3: %s", err)
        raise

    with Image.open(io.BytesIO(image_data)) as img:
        width, height = img.size
        # Round each side up to the next multiple of 64.
        new_width = ((width + 63) // 64) * 64
        new_height = ((height + 63) // 64) * 64
        img_resized = img.resize((new_width, new_height), Image.LANCZOS)
        # A BytesIO has no filename, so Pillow cannot infer the output format
        # from an extension and save() would raise ValueError. Reuse the
        # source format, which is guaranteed to support the image's mode.
        output_format = img.format or "PNG"
    if output_format == "MPO":
        # Multi-picture JPEGs (from some cameras) are written as plain JPEG.
        output_format = "JPEG"

    buffer = io.BytesIO()
    img_resized.save(buffer, format=output_format)
    resized_image_data = buffer.getvalue()
    logger.info(
        "Resized image from %sx%s to %sx%s", width, height, new_width, new_height
    )
    return base64.b64encode(resized_image_data).decode('utf8')
_INPUT_PREFIX = "input/"
_OUTPUT_PREFIX = "output/"

# Prompt sent to SDXL alongside the uploaded image.
_HEADSHOT_PROMPT = (
    "Analyze the input image to detect the face and focus on the full headshot, "
    "including the top of the head and shoulders. "
    "Adjust the crop area dynamically based on the detected face to include "
    "additional padding above the head (e.g., 20% of the face height) and around "
    "the sides for a balanced and professional appearance. "
    "Generate a circular image with transparent padding to ensure the face and "
    "head are fully visible within a 150x150 pixel and a 175x175 pixel circular "
    "frame. "
    "Ensure the circular area is centered, and apply a subtle feathered edge to "
    "blend the circular\n"
    "content smoothly with the transparent background."
)


def _output_key_for(object_key):
    """Map ``input/<path>`` to ``output/<path>``, turning a trailing ``.jpg`` into ``_processed.jpg``.

    Only the leading prefix and a trailing ``.jpg`` are rewritten, so keys
    that contain ``input/`` or ``.jpg`` elsewhere are not mangled.
    """
    key = _OUTPUT_PREFIX + object_key[len(_INPUT_PREFIX):]
    if key.endswith(".jpg"):
        key = key[:-len(".jpg")] + "_processed.jpg"
    return key


def _content_type_for(image_bytes):
    """Return ``image/png`` when *image_bytes* has the PNG signature, else ``image/jpeg``."""
    if image_bytes.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    return "image/jpeg"


def lambda_handler(event, context):
    """
    Lambda handler function for triggering on S3 upload events.

    Reads the uploaded object under ``input/``, sends it to the SDXL model
    as ``init_image`` with a fixed headshot prompt, and writes the result to
    the same bucket under ``output/``.

    Args:
        event (dict): The S3 notification event. Only the first record is
            processed.
        context: The Lambda context object (unused).

    Returns:
        dict | None: ``{"statusCode": 200, "body": ...}`` on success, or
        ``None`` when the object is outside the ``input/`` prefix.

    Raises:
        KeyError, IndexError: If the event lacks the expected S3 record.
        botocore.exceptions.ClientError: On S3 or Bedrock API failures.
        ImageError: If the model reports a generation error.
    """
    model_id = 'stability.stable-diffusion-xl-v1'

    # Extract bucket name and object key from the S3 event.
    try:
        s3_record = event['Records'][0]['s3']
        bucket_name = s3_record['bucket']['name']
        # Object keys in S3 event notifications are URL-encoded (a space
        # arrives as '+'); decode before using the key with the S3 API.
        object_key = urllib.parse.unquote_plus(s3_record['object']['key'])
    except (KeyError, IndexError) as err:
        logger.error("Failed to extract required information from event: %s", err)
        raise
    logger.info("Triggered by S3 bucket: %s, key: %s", bucket_name, object_key)

    # Only process uploads under input/; this also keeps the function's own
    # writes to output/ from re-triggering processing.
    if not object_key.startswith(_INPUT_PREFIX):
        logger.warning("Skipping object with key: %s (not in 'input/' prefix)", object_key)
        return

    try:
        init_image = get_image_from_s3(bucket_name, object_key)
        body = json.dumps({
            "text_prompts": [
                {
                    "text": _HEADSHOT_PROMPT
                }
            ],
            "init_image": init_image,
            "style_preset": "isometric"
        })
        image_bytes = generate_image(model_id=model_id, body=body)

        # Save the generated image back to S3.
        output_key = _output_key_for(object_key)
        s3 = boto3.client('s3')
        s3.put_object(
            Bucket=bucket_name,
            Key=output_key,
            Body=image_bytes,
            ContentType=_content_type_for(image_bytes),
        )
        logger.info("Successfully saved processed image to S3 bucket: %s, key: %s", bucket_name, output_key)
    except ClientError as err:
        message = err.response["Error"]["Message"]
        logger.error("A client error occurred: %s", message)
        raise
    except ImageError as err:
        logger.error(err.message)
        raise

    return {
        "statusCode": 200,
        "body": json.dumps(f"Finished processing image from {object_key}")
    }