fix: Unable to display images generated by Dall-E 3 (#6155)

This commit is contained in:
Weishan-0 2024-07-18 19:37:04 +08:00 committed by GitHub
parent 4a026fa352
commit 7b45a5d452
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,5 @@
import base64
import random
from base64 import b64decode
from typing import Any, Union
from openai import OpenAI
@ -69,11 +69,50 @@ class DallE3Tool(BuiltinTool):
result = []
for image in response.data:
result.append(self.create_blob_message(blob=b64decode(image.b64_json),
meta={'mime_type': 'image/png'},
save_as=self.VARIABLE_KEY.IMAGE.value))
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
blob_message = self.create_blob_message(blob=blob_image,
meta={'mime_type': mime_type},
save_as=self.VARIABLE_KEY.IMAGE.value)
result.append(blob_message)
return result
@staticmethod
def _decode_image(base64_image: str) -> tuple[str, bytes]:
    """
    Turn a base64 image string into its MIME type and raw bytes.

    Input that starts with a ``data:image`` prefix is handed to
    ``_extract_mime_and_data``; anything else is treated as a bare base64
    payload and reported as 'image/png'.

    :param base64_image: Base64 encoded image, with or without a MIME type prefix
    :return: A ``(mime_type, image_bytes)`` tuple
    """
    has_mime_prefix = not DallE3Tool._is_plain_base64(base64_image)
    if has_mime_prefix:
        return DallE3Tool._extract_mime_and_data(base64_image)
    # No prefix to read a MIME type from, so fall back to PNG.
    return 'image/png', base64.b64decode(base64_image)
@staticmethod
def _is_plain_base64(encoded_str: str) -> bool:
    """
    Tell whether the string is bare base64, i.e. has no ``data:image`` prefix.

    :param encoded_str: Base64 encoded image string
    :return: False when the string begins with 'data:image', True otherwise
    """
    has_data_uri_prefix = encoded_str.startswith('data:image')
    return not has_data_uri_prefix
@staticmethod
def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]:
    """
    Extract MIME type and image data from a base64 encoded string with a MIME type prefix.

    The expected shape is ``data:<mime>[;param...],<base64 data>``, for example
    ``data:image/png;base64,iVBORw0...``. The header is separated from the
    payload first, so a prefix without any ``;`` parameter
    (``data:image/png,<data>``) still yields the correct MIME type.

    :param encoded_str: Base64 encoded image string with MIME type prefix
    :return: A tuple containing the MIME type and the decoded image bytes
    :raises ValueError: If the ',' separator or the MIME type is missing
    """
    # Split on the first comma only: the left side is the header, the right
    # side is the base64 payload.
    header, comma, image_data_base64 = encoded_str.partition(',')
    if not comma:
        raise ValueError("Malformed image data: missing ',' between the MIME prefix and the base64 data")

    # The header looks like "data:image/png;base64"; the MIME type sits between
    # the first ':' and the first ';' (if any parameters follow).
    _, colon, media_type = header.partition(':')
    mime_type = media_type.split(';')[0]
    if not colon or not mime_type:
        raise ValueError("Malformed image data: missing MIME type in the prefix")

    decoded_data = base64.b64decode(image_data_base64)
    return mime_type, decoded_data
@staticmethod
def _generate_random_id(length=8):
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'