Commit 2107e3b4 authored by duanjinfei's avatar duanjinfei

update predict

parent bd337fb8
......@@ -4,7 +4,7 @@
from cog import BasePredictor, Input, Path
from typing import List
from modelscope.pipelines import pipeline
from util import save_images, read_images_in_folder
from util import save_images, read_img_result
img_save_folder = "SaveImages"
params = {
......@@ -33,7 +33,7 @@ class Predictor(BasePredictor):
description="", default='https://replicate.delivery/pbxt/0eNDe5B73njlkkSJfJeZ4Riiww3OYci679bnbaf0cNzLh2tRC/image_1.png'),
use_fp32: bool = Input(description="", default=False),
no_translator: bool = Input(description="", default=False),
) -> List[Path]:
) -> List[str]:
input_data = {}
if mode == "text-generation":
input_data = {
......@@ -58,8 +58,9 @@ class Predictor(BasePredictor):
if rtn_warning:
print(rtn_warning)
return []
res = []
if rtn_code >= 0:
save_images(results, img_save_folder)
res = read_img_result(results)
# save_images(results, img_save_folder)
print(f'Done, result images are saved in: {img_save_folder}')
images = read_images_in_folder(img_save_folder)
return images
return res
......@@ -42,6 +42,16 @@ def save_images(img_list, folder):
cv2.imwrite(save_path, img[..., ::-1])
def read_img_result(img_list):
res = []
for _, img in enumerate(img_list):
_, img_encoded = cv2.imencode('.jpg', img)
# Encode image as base64
base64_encoded = base64.b64encode(img_encoded).decode('utf-8')
res.append(base64_encoded)
return res
def check_channels(image):
channels = image.shape[2] if len(image.shape) == 3 else 1
if channels == 1:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment