Commit 2107e3b4 authored by duanjinfei's avatar duanjinfei

update predict

parent bd337fb8
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from cog import BasePredictor, Input, Path from cog import BasePredictor, Input, Path
from typing import List from typing import List
from modelscope.pipelines import pipeline from modelscope.pipelines import pipeline
from util import save_images, read_images_in_folder from util import save_images, read_img_result
img_save_folder = "SaveImages" img_save_folder = "SaveImages"
params = { params = {
...@@ -33,7 +33,7 @@ class Predictor(BasePredictor): ...@@ -33,7 +33,7 @@ class Predictor(BasePredictor):
description="", default='https://replicate.delivery/pbxt/0eNDe5B73njlkkSJfJeZ4Riiww3OYci679bnbaf0cNzLh2tRC/image_1.png'), description="", default='https://replicate.delivery/pbxt/0eNDe5B73njlkkSJfJeZ4Riiww3OYci679bnbaf0cNzLh2tRC/image_1.png'),
use_fp32: bool = Input(description="", default=False), use_fp32: bool = Input(description="", default=False),
no_translator: bool = Input(description="", default=False), no_translator: bool = Input(description="", default=False),
) -> List[Path]: ) -> List[str]:
input_data = {} input_data = {}
if mode == "text-generation": if mode == "text-generation":
input_data = { input_data = {
...@@ -58,8 +58,9 @@ class Predictor(BasePredictor): ...@@ -58,8 +58,9 @@ class Predictor(BasePredictor):
if rtn_warning: if rtn_warning:
print(rtn_warning) print(rtn_warning)
return [] return []
res = []
if rtn_code >= 0: if rtn_code >= 0:
save_images(results, img_save_folder) res = read_img_result(results)
# save_images(results, img_save_folder)
print(f'Done, result images are saved in: {img_save_folder}') print(f'Done, result images are saved in: {img_save_folder}')
images = read_images_in_folder(img_save_folder) return res
return images
...@@ -42,6 +42,16 @@ def save_images(img_list, folder): ...@@ -42,6 +42,16 @@ def save_images(img_list, folder):
cv2.imwrite(save_path, img[..., ::-1]) cv2.imwrite(save_path, img[..., ::-1])
def read_img_result(img_list):
res = []
for _, img in enumerate(img_list):
_, img_encoded = cv2.imencode('.jpg', img)
# Encode image as base64
base64_encoded = base64.b64encode(img_encoded).decode('utf-8')
res.append(base64_encoded)
return res
def check_channels(image): def check_channels(image):
channels = image.shape[2] if len(image.shape) == 3 else 1 channels = image.shape[2] if len(image.shape) == 3 else 1
if channels == 1: if channels == 1:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment