# Prediction interface for Cog ⚙️
# https://cog.run/python

from typing import List, Optional

from cog import BasePredictor, Input, Path
from modelscope.pipelines import pipeline
# from util import save_images

# Intended output folder name for result images. Nothing visible in this
# module currently writes to it (the save_images import above is commented out).
img_save_folder = "SaveImages"

# Keyword arguments forwarded verbatim (as ``**params``) to the AnyText
# pipeline call in Predictor.predict. The per-key meanings below are inferred
# from the key names only — TODO confirm against the pipeline implementation.
params = dict(
    show_debug=True,   # presumably asks the pipeline for extra debug output
    image_count=2,     # presumably the number of images produced per request
    ddim_steps=20,     # presumably the number of DDIM sampling steps
)


class Predictor(BasePredictor):
    """Cog predictor wrapping the ModelScope AnyText pipeline.

    ``setup`` builds the pipeline once; ``predict`` runs it for a single
    request in either "text-generation" or "text-editing" mode.
    """

    def setup(self) -> None:
        """Build the AnyText pipeline, pinned to model revision v1.1.3."""
        self.model = pipeline('my-anytext-task',
                              model='damo/cv_anytext_text_generation_editing', model_revision='v1.1.3')

    def predict(
        self,
        mode: str = Input(description="Select model type", default="text-generation", choices=[
            "text-generation", "text-editing"]),
        prompt: str = Input(description="Input prompt",
                            default='photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'),
        seed: int = Input(description="", default=66273235,
                          ge=0, le=66273235),
        draw_pos: Path = Input(
            description="", default='example_images/gen9.png'),
        ori_image: Path = Input(
            description="", default='example_images/ref7.jpg'),
        use_fp32: bool = Input(description="", default=False),
        no_translator: bool = Input(description="", default=False),
        font_path: str = Input(
            description="", default='font/Arial_Unicode.ttf'),
        model_path: Optional[str] = Input(description="", default=None)
    ) -> List[Path]:
        """Run the AnyText pipeline once and return its result items.

        Args:
            mode: "text-generation" or "text-editing"; forwarded as ``mode=``.
            prompt: Text prompt for the pipeline.
            seed: Random seed forwarded to the pipeline.
            draw_pos: Image file passed to the pipeline as ``draw_pos``.
            ori_image: Image file passed to the pipeline as ``ori_image``.
            use_fp32, no_translator, font_path, model_path: Forwarded in the
                input dict unchanged.

        Returns:
            The pipeline's result items (passed through unchanged) when it
            reports a non-negative return code.

        Raises:
            RuntimeError: if the pipeline reports a negative return code.
        """
        input_data = {
            "prompt": prompt,
            "seed": seed,
            # cog's Path is a pathlib subclass, not ``str``; pass plain strings
            # so a pipeline that only treats ``str`` values as file paths can
            # load them (assumption about the pipeline — TODO confirm).
            "draw_pos": str(draw_pos) if draw_pos is not None else None,
            "ori_image": str(ori_image) if ori_image is not None else None,
            # NOTE(review): these look like pipeline-construction options;
            # confirm the pipeline actually reads them from the input dict.
            "use_fp32": use_fp32,
            "no_translator": no_translator,
            "font_path": font_path,
            "model_path": model_path
        }
        results, rtn_code, rtn_warning, _debug_info = self.model(
            input_data, mode=mode, **params)
        if rtn_code < 0:
            # A negative code is a failure: report it instead of silently
            # returning an empty output that looks like success.
            raise RuntimeError(
                rtn_warning or f"AnyText pipeline failed with return code {rtn_code}")
        if rtn_warning:
            # Non-fatal warning on a successful run: keep the results.
            print(rtn_warning)
        # NOTE(review): items are passed through unchanged. If the pipeline
        # returns in-memory images rather than file paths, they must be written
        # to disk before Cog can serialize them as List[Path] — confirm.
        files = list(results or [])
        print(f'Done, {len(files)} result item(s) produced')
        return files
