From 70d6ab0bf5c68d44f1a62dc3e4e744f6a31fe0e3 Mon Sep 17 00:00:00 2001 From: "Jie.F" Date: Fri, 23 Aug 2024 18:58:13 +0800 Subject: [PATCH] Update stable_diffusion.py (#7536) --- .../stablediffusion/tools/stable_diffusion.py | 41 ++++++++++++++++++- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index 0c5ebc23a..4be9207d6 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -27,7 +27,7 @@ DRAW_TEXT_OPTIONS = { "seed_resize_from_w": -1, # Samplers - # "sampler_name": "DPM++ 2M", + "sampler_name": "DPM++ 2M", # "scheduler": "", # "sampler_index": "Automatic", @@ -178,6 +178,23 @@ class StableDiffusionTool(BuiltinTool): return [d['model_name'] for d in response.json()] except Exception as e: return [] + + def get_sample_methods(self) -> list[str]: + """ + get sample method + """ + try: + base_url = self.runtime.credentials.get('base_url', None) + if not base_url: + return [] + api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'samplers') + response = get(url=api_url, timeout=(2, 10)) + if response.status_code != 200: + return [] + else: + return [d['name'] for d in response.json()] + except Exception as e: + return [] def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \ -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: @@ -339,7 +356,27 @@ class StableDiffusionTool(BuiltinTool): label=I18nObject(en_US=i, zh_Hans=i) ) for i in models]) ) + except: pass - + + sample_methods = self.get_sample_methods() + if len(sample_methods) != 0: + parameters.append( + ToolParameter(name='sampler_name', + label=I18nObject(en_US='Sampling method', zh_Hans='Sampling method'), + human_description=I18nObject( + en_US='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion', + zh_Hans='Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档', + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description='Sampling method of Stable Diffusion, you can check the official documentation of Stable Diffusion', + required=True, + default=sample_methods[0], + options=[ToolParameterOption( + value=i, + label=I18nObject(en_US=i, zh_Hans=i) + ) for i in sample_methods]) + ) return parameters