Update stable_diffusion.py (#7536)

This commit is contained in:
Jie.F 2024-08-23 18:58:13 +08:00 committed by GitHub
parent e42848f4b7
commit 70d6ab0bf5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -27,7 +27,7 @@ DRAW_TEXT_OPTIONS = {
"seed_resize_from_w": -1,
# Samplers
# "sampler_name": "DPM++ 2M",
"sampler_name": "DPM++ 2M",
# "scheduler": "",
# "sampler_index": "Automatic",
@ -178,6 +178,23 @@ class StableDiffusionTool(BuiltinTool):
return [d['model_name'] for d in response.json()]
except Exception as e:
return []
def get_sample_methods(self) -> list[str]:
"""
get sample method
"""
try:
base_url = self.runtime.credentials.get('base_url', None)
if not base_url:
return []
api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'samplers')
response = get(url=api_url, timeout=(2, 10))
if response.status_code != 200:
return []
else:
return [d['name'] for d in response.json()]
except Exception as e:
return []
def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
@ -339,7 +356,27 @@ class StableDiffusionTool(BuiltinTool):
label=I18nObject(en_US=i, zh_Hans=i)
) for i in models])
)
except:
pass
sample_methods = self.get_sample_methods()
if len(sample_methods) != 0:
parameters.append(
ToolParameter(name='sampler_name',
label=I18nObject(en_US='Sampling method', zh_Hans='Sampling method'),
human_description=I18nObject(
en_US='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion',
zh_Hans='Stable Diffusion 的Sampling method您可以查看 Stable Diffusion 的官方文档',
),
type=ToolParameter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
llm_description='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion',
required=True,
default=sample_methods[0],
options=[ToolParameterOption(
value=i,
label=I18nObject(en_US=i, zh_Hans=i)
) for i in sample_methods])
)
return parameters