Commit bd337fb8 authored by duanjinfei's avatar duanjinfei

update predict

parent 1bbbfb45
......@@ -4,7 +4,7 @@
from cog import BasePredictor, Input, Path
from typing import List
from modelscope.pipelines import pipeline
# from util import save_images
from util import save_images, read_images_in_folder
img_save_folder = "SaveImages"
params = {
......@@ -28,34 +28,38 @@ class Predictor(BasePredictor):
seed: int = Input(description="", default=66273235,
ge=0, le=66273235),
draw_pos: Path = Input(
description="", default='example_images/gen9.png'),
description="", default='https://replicate.delivery/pbxt/WwkQofrF9CQuJ6v3slDfyha3YB1teWJpnWLBsVSdvfeFh2tRC/image_0.png'),
ori_image: Path = Input(
description="", default='example_images/ref7.jpg'),
description="", default='https://replicate.delivery/pbxt/0eNDe5B73njlkkSJfJeZ4Riiww3OYci679bnbaf0cNzLh2tRC/image_1.png'),
use_fp32: bool = Input(description="", default=False),
no_translator: bool = Input(description="", default=False),
font_path: str = Input(
description="", default='font/Arial_Unicode.ttf'),
model_path: str = Input(description="", default=None)
) -> List[Path]:
input_data = {
"prompt": prompt,
"seed": seed,
"draw_pos": draw_pos,
"ori_image": ori_image,
"use_fp32": use_fp32,
"no_translator": no_translator,
"font_path": font_path,
"model_path": model_path
}
input_data = {}
if mode == "text-generation":
input_data = {
"prompt": prompt,
"seed": seed,
"draw_pos": str(draw_pos),
"use_fp32": use_fp32,
"no_translator": no_translator,
}
else:
input_data = {
"prompt": prompt,
"seed": seed,
"draw_pos": str(draw_pos),
"ori_image": str(ori_image),
"use_fp32": use_fp32,
"no_translator": no_translator,
}
print("input-data:", input_data)
results, rtn_code, rtn_warning, debug_info = self.model(
input_data, mode=mode, **params)
if rtn_warning:
print(rtn_warning)
return []
files = []
if rtn_code >= 0:
# save_images(results, img_save_folder)
save_images(results, img_save_folder)
print(f'Done, result images are saved in: {img_save_folder}')
for file in results:
files.append(file)
return files
images = read_images_in_folder(img_save_folder)
return images
import datetime
import os
import cv2
import base64
def read_image_to_base64(file_path):
with open(file_path, "rb") as img_file:
img_data = img_file.read()
base64_encoded = base64.b64encode(img_data).decode('utf-8')
return base64_encoded
def read_images_in_folder(folder_path):
image_files = [f for f in os.listdir(
folder_path) if os.path.isfile(os.path.join(folder_path, f))]
base64_images = []
for file in image_files:
file_path = os.path.join(folder_path, file)
if file_path.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
base64_images.append({
"filename": file,
"base64": read_image_to_base64(file_path)
})
return base64_images
def save_images(img_list, folder):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment