Commit 1bbbfb45 authored by duanjinfei

add Anytext

parent d65ced0b
'''
Copyright (c) Alibaba, Inc. and its affiliates.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from ldm.modules.diffusionmodules.util import conv_nd, linear, zero_module
def get_clip_token_for_string(tokenizer, string):
batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"]
assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
return tokens[0, 1]
def get_bert_token_for_string(tokenizer, string):
token = tokenizer(string)
assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
token = token[0, 1]
return token
def get_clip_vision_emb(encoder, processor, img):
_img = img.repeat(1, 3, 1, 1)*255
inputs = processor(images=_img, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].to(img.device)
outputs = encoder(**inputs)
emb = outputs.image_embeds
return emb
def get_recog_emb(encoder, img_list):
_img_list = [(img.repeat(1, 3, 1, 1)*255)[0] for img in img_list]
encoder.predictor.eval()
_, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
return preds_neck
def pad_H(x):
_, _, H, W = x.shape
p_top = (W - H) // 2
p_bot = W - H - p_top
return F.pad(x, (0, 0, p_top, p_bot))
class EncodeNet(nn.Module):
def __init__(self, in_channels, out_channels):
super(EncodeNet, self).__init__()
chan = 16
n_layer = 4 # downsample
self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
self.conv_list = nn.ModuleList([])
_c = chan
for i in range(n_layer):
self.conv_list.append(conv_nd(2, _c, _c*2, 3, padding=1, stride=2))
_c *= 2
self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.act = nn.SiLU()
def forward(self, x):
x = self.act(self.conv1(x))
for layer in self.conv_list:
x = self.act(layer(x))
x = self.act(self.conv2(x))
x = self.avgpool(x)
x = x.view(x.size(0), -1)
return x
class EmbeddingManager(nn.Module):
def __init__(
self,
embedder,
valid=True,
glyph_channels=20,
position_channels=1,
placeholder_string='*',
add_pos=False,
emb_type='ocr',
**kwargs
):
super().__init__()
if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder
get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
token_dim = 768
if hasattr(embedder, 'vit'):
assert emb_type == 'vit'
self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor)
self.get_recog_emb = None
else: # using LDM's BERT encoder
get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
token_dim = 1280
self.token_dim = token_dim
self.emb_type = emb_type
self.add_pos = add_pos
if add_pos:
self.position_encoder = EncodeNet(position_channels, token_dim)
if emb_type == 'ocr':
self.proj = nn.Sequential(
zero_module(linear(40*64, token_dim)),
nn.LayerNorm(token_dim)
)
if emb_type == 'conv':
self.glyph_encoder = EncodeNet(glyph_channels, token_dim)
self.placeholder_token = get_token_for_string(placeholder_string)
def encode_text(self, text_info):
if self.get_recog_emb is None and self.emb_type == 'ocr':
self.get_recog_emb = partial(get_recog_emb, self.recog)
gline_list = []
pos_list = []
for i in range(len(text_info['n_lines'])): # sample index in a batch
n_lines = text_info['n_lines'][i]
for j in range(n_lines): # line
gline_list += [text_info['gly_line'][j][i:i+1]]
if self.add_pos:
pos_list += [text_info['positions'][j][i:i+1]]
if len(gline_list) > 0:
if self.emb_type == 'ocr':
recog_emb = self.get_recog_emb(gline_list)
enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
elif self.emb_type == 'vit':
enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
elif self.emb_type == 'conv':
enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
if self.add_pos:
enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
enc_glyph = enc_glyph+enc_pos
self.text_embs_all = []
n_idx = 0
for i in range(len(text_info['n_lines'])): # sample index in a batch
n_lines = text_info['n_lines'][i]
text_embs = []
for j in range(n_lines): # line
text_embs += [enc_glyph[n_idx:n_idx+1]]
n_idx += 1
self.text_embs_all += [text_embs]
def forward(
self,
tokenized_text,
embedded_text,
):
b, device = tokenized_text.shape[0], tokenized_text.device
for i in range(b):
idx = tokenized_text[i] == self.placeholder_token.to(device)
if sum(idx) > 0:
if i >= len(self.text_embs_all):
print('truncation for log images...')
break
text_emb = torch.cat(self.text_embs_all[i], dim=0)
if sum(idx) != len(text_emb):
print('truncation for long caption...')
embedded_text[i][idx] = text_emb[:sum(idx)]
return embedded_text
def embedding_parameters(self):
return self.parameters()
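# Minimal shape check (illustrative sketch, not part of the original file): EncodeNet maps a
# (B, C, H, W) map to a flat (B, out_channels) vector via four stride-2 convs plus global
# average pooling; pad_H first pads a wide, short map up to a square along the height axis.
if __name__ == "__main__":
    pos = torch.randn(2, 1, 80, 512)          # e.g. a rendered glyph line, (B, C, H, W)
    squared = pad_H(pos)                      # pad height up to the width
    print(squared.shape)                      # torch.Size([2, 1, 512, 512])
    enc = EncodeNet(in_channels=1, out_channels=768)
    print(enc(squared).shape)                 # torch.Size([2, 768])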
import torch
import einops
import ldm.modules.encoders.modules
import ldm.modules.attention
from transformers import logging
from ldm.modules.attention import default
def disable_verbosity():
logging.set_verbosity_error()
print('logging improved.')
return
def enable_sliced_attention():
ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
print('Enabled sliced_attention.')
return
def hack_everything(clip_skip=0):
disable_verbosity()
ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
print('Enabled clip hacks.')
return
# Written by Lvmin
def _hacked_clip_forward(self, text):
PAD = self.tokenizer.pad_token_id
EOS = self.tokenizer.eos_token_id
BOS = self.tokenizer.bos_token_id
def tokenize(t):
return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
def transformer_encode(t):
if self.clip_skip > 1:
rt = self.transformer(input_ids=t, output_hidden_states=True)
return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
else:
return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
def split(x):
return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
def pad(x, p, i):
return x[:i] if len(x) >= i else x + [p] * (i - len(x))
raw_tokens_list = tokenize(text)
tokens_list = []
for raw_tokens in raw_tokens_list:
raw_tokens_123 = split(raw_tokens)
raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
tokens_list.append(raw_tokens_123)
tokens_list = torch.IntTensor(tokens_list).to(self.device)
feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
y = transformer_encode(feed)
z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
return z
# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
limit = k.shape[0]
att_step = 1
q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
q_chunks.reverse()
k_chunks.reverse()
v_chunks.reverse()
sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
del k, q, v
for i in range(0, limit, att_step):
q_buffer = q_chunks.pop()
k_buffer = k_chunks.pop()
v_buffer = v_chunks.pop()
sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
del k_buffer, q_buffer
# attention, what we cannot get enough of, by chunks
sim_buffer = sim_buffer.softmax(dim=-1)
sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
del v_buffer
sim[i:i + att_step, :, :] = sim_buffer
del sim_buffer
sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
return self.to_out(sim)
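# Usage sketch (assumption: mirrors how ControlNet-style launch scripts apply these patches
# at startup; nothing below runs automatically on import):
if __name__ == "__main__":
    hack_everything(clip_skip=0)      # quiet transformers logging + patch CLIP forward for >77-token prompts
    enable_sliced_attention()         # optional: chunked cross-attention to lower peak memory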
import os
import numpy as np
import torch
import torchvision
from PIL import Image
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.distributed import rank_zero_only
class ImageLogger(Callback):
def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
log_images_kwargs=None):
super().__init__()
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
if not increase_log_steps:
self.log_steps = [self.batch_freq]
self.clamp = clamp
self.disabled = disabled
self.log_on_batch_idx = log_on_batch_idx
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
self.log_first_step = log_first_step
@rank_zero_only
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
root = os.path.join(save_dir, "image_log", split)
for k in images:
grid = torchvision.utils.make_grid(images[k], nrow=4)
if self.rescale:
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
grid = grid.numpy()
grid = (grid * 255).astype(np.uint8)
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
path = os.path.join(root, filename)
os.makedirs(os.path.split(path)[0], exist_ok=True)
Image.fromarray(grid).save(path)
def log_img(self, pl_module, batch, batch_idx, split="train"):
check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step
if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
hasattr(pl_module, "log_images") and
callable(pl_module.log_images) and
self.max_images > 0):
logger = type(pl_module.logger)
is_train = pl_module.training
if is_train:
pl_module.eval()
with torch.no_grad():
images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
for k in images:
N = min(images[k].shape[0], self.max_images)
images[k] = images[k][:N]
if isinstance(images[k], torch.Tensor):
images[k] = images[k].detach().cpu()
if self.clamp:
images[k] = torch.clamp(images[k], -1., 1.)
self.log_local(pl_module.logger.save_dir, split, images,
pl_module.global_step, pl_module.current_epoch, batch_idx)
if is_train:
pl_module.train()
def check_frequency(self, check_idx):
return check_idx % self.batch_freq == 0
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
if not self.disabled:
self.log_img(pl_module, batch, batch_idx, split="train")
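# Wiring sketch (assumption: standard pytorch_lightning 1.5 usage as in ControlNet-style
# training scripts; the Trainer arguments below are illustrative, not taken from this commit):
if __name__ == "__main__":
    from pytorch_lightning import Trainer
    logger_callback = ImageLogger(batch_frequency=300, max_images=4)
    trainer = Trainer(precision=32, callbacks=[logger_callback])
    # trainer.fit(model, train_dataloader)  # sampled images land in <save_dir>/image_log/train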
import os
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
def get_state_dict(d):
return d.get('state_dict', d)
def load_state_dict(ckpt_path, location='cpu'):
_, extension = os.path.splitext(ckpt_path)
if extension.lower() == ".safetensors":
import safetensors.torch
state_dict = safetensors.torch.load_file(ckpt_path, device=location)
else:
state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
state_dict = get_state_dict(state_dict)
print(f'Loaded state_dict from [{ckpt_path}]')
return state_dict
def create_model(config_path, cond_stage_path=None, use_fp16=False):
config = OmegaConf.load(config_path)
if cond_stage_path:
config.model.params.cond_stage_config.params.version = cond_stage_path # use pre-downloaded ckpts, in case blocked
if use_fp16:
config.model.params.use_fp16 = True
config.model.params.control_stage_config.params.use_fp16 = True
config.model.params.unet_config.params.use_fp16 = True
model = instantiate_from_config(config.model).cpu()
print(f'Loaded model config from [{config_path}]')
return model
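# Typical usage sketch (paths are illustrative; this mirrors how the eval scripts below load
# AnyText): build the network from its YAML config, then load checkpoint weights.
if __name__ == "__main__":
    model = create_model('./models_yaml/anytext_sd15.yaml')
    model.load_state_dict(load_state_dict('./models/anytext_v1.1.ckpt', location='cpu'), strict=True)
    # model = model.cuda().eval()  # move to GPU before sampling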
name: anytext
channels:
- pytorch
- defaults
dependencies:
- python=3.10.6
- pip=23.0.1
- cudatoolkit=11.7
- numpy=1.23.3
- cython==0.29.33
- pip:
- Pillow==9.5.0
- gradio==3.50.0
- albumentations==0.4.3
- opencv-python==4.7.0.72
- imageio==2.9.0
- imageio-ffmpeg==0.4.2
- pytorch-lightning==1.5.0
- omegaconf==2.2.3
- test-tube==0.7.5
- streamlit==1.20.0
- einops==0.4.1
- transformers==4.30.2
- webdataset==0.2.5
- kornia==0.6.7
- open_clip_torch==2.7.0
- torchmetrics==0.11.4
- timm==0.6.7
- addict==2.4.0
- yapf==0.32.0
- safetensors==0.4.0
- basicsr==1.4.2
- jieba==0.42.1
- modelscope==1.10.0
- tensorflow==2.13.0
- torch==2.0.1
- torchvision==0.15.2
- easydict==1.10
- xformers==0.0.20
- subword-nmt==0.3.8
- sacremoses==0.0.53
- sentencepiece==0.1.99
- fsspec
- diffusers==0.10.2
- ujson
import os
import shutil
import copy
import argparse
import pathlib
import json
def load(file_path: str):
file_path = pathlib.Path(file_path)
func_dict = {'.json': load_json}
assert file_path.suffix in func_dict
return func_dict[file_path.suffix](file_path)
def load_json(file_path: str):
with open(file_path, 'r', encoding='utf8') as f:
content = json.load(f)
return content
def save(data, file_path):
file_path = pathlib.Path(file_path)
func_dict = {'.json': save_json}
assert file_path.suffix in func_dict
return func_dict[file_path.suffix](data, file_path)
def save_json(data, file_path):
with open(file_path, 'w', encoding='utf-8') as json_file:
json.dump(data, json_file, ensure_ascii=False, indent=4)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_path",
type=str,
default='models/anytext_v1.1.ckpt',
help='path of model'
)
parser.add_argument(
"--gpus",
type=str,
default='0,1,2,3,4,5,6,7',
help='gpus for inference'
)
parser.add_argument(
"--output_dir",
type=str,
default='./anytext_v1.1_laion_generated/',
help="output path"
)
parser.add_argument(
"--json_path",
type=str,
default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json',
help="json path for evaluation dataset"
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
ckpt_path = args.model_path
gpus = args.gpus
output_dir = args.output_dir
json_path = args.json_path
USING_DLC = False
if USING_DLC:
json_path = json_path.replace('/data/vdb', '/mnt/data', 1)
output_dir = output_dir.replace('/data/vdb', '/mnt/data', 1)
exec_path = './eval/anytext_singleGPU.py'
continue_gen = True # if True, do not clear output_dir; only generate the remaining images.
tmp_dir = './tmp_dir'
if os.path.exists(tmp_dir):
shutil.rmtree(tmp_dir)
os.makedirs(tmp_dir)
if not continue_gen:
if os.path.exists(output_dir):
shutil.rmtree(output_dir)
os.makedirs(output_dir)
else:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
os.system('sleep 1')
gpu_ids = [int(i) for i in gpus.split(',')]
nproc = len(gpu_ids)
all_lines = load(json_path)
split_file = []
length = len(all_lines['data_list']) // nproc
cmds = []
for i in range(nproc):
start, end = i*length, (i+1)*length
if i == nproc - 1:
end = len(all_lines['data_list'])
temp_lines = copy.deepcopy(all_lines)
temp_lines['data_list'] = temp_lines['data_list'][start:end]
tmp_file = os.path.join(tmp_dir, f'tmp_list_{i}.json')
save(temp_lines, tmp_file)
os.system('sleep 1')
cmds += [f'export CUDA_VISIBLE_DEVICES={gpu_ids[i]} && python {exec_path} --input_json {tmp_file} --output_dir {output_dir} --ckpt_path {ckpt_path} && echo proc-{i} done!']
cmds = ' & '.join(cmds)
os.system(cmds)
print('Done.')
os.system('sleep 2')
shutil.rmtree(tmp_dir)
'''
command to kill the task after running:
$ps -ef | grep singleGPU | awk '{ print $2 }' | xargs kill -9 && ps -ef | grep multiproce | awk '{ print $2 }' | xargs kill -9
'''
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import cv2
import einops
import numpy as np
import torch
import random
from PIL import ImageFont
from pytorch_lightning import seed_everything
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from t3_dataset import draw_glyph, draw_glyph2, get_caption_pos
from dataset_util import load
from tqdm import tqdm
import argparse
import time
save_memory = False
# parameters
config_yaml = './models_yaml/anytext_sd15.yaml'
ckpt_path = './models/anytext_v1.0.ckpt'
json_path = '/data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json'
output_dir = '/data/vdb/yuxiang.tyx/AIGC/eval/gen_imgs_test'
num_samples = 4
image_resolution = 512
strength = 1.0
ddim_steps = 20
scale = 9.0
seed = 100
eta = 0.0
a_prompt = 'best quality, extremely detailed'
n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, watermark'
PLACE_HOLDER = '*'
max_chars = 20
max_lines = 20
font = ImageFont.truetype('./font/Arial_Unicode.ttf', size=60)
def parse_args():
parser = argparse.ArgumentParser(description='generate images')
parser.add_argument('--input_json', type=str, default=json_path)
parser.add_argument('--output_dir', type=str, default=output_dir)
parser.add_argument('--ckpt_path', type=str, default=ckpt_path)
args = parser.parse_args()
return args
def arr2tensor(arr, bs):
arr = np.transpose(arr, (2, 0, 1))
_arr = torch.from_numpy(arr.copy()).float().cuda()
_arr = torch.stack([_arr for _ in range(bs)], dim=0)
return _arr
def load_data(input_path):
content = load(input_path)
d = []
count = 0
for gt in content['data_list']:
info = {}
info['img_name'] = gt['img_name']
info['caption'] = gt['caption']
if PLACE_HOLDER in info['caption']:
count += 1
info['caption'] = info['caption'].replace(PLACE_HOLDER, " ")
if 'annotations' in gt:
polygons = []
texts = []
pos = []
for annotation in gt['annotations']:
if len(annotation['polygon']) == 0:
continue
if annotation['valid'] is False:
continue
polygons.append(annotation['polygon'])
texts.append(annotation['text'])
pos.append(annotation['pos'])
info['polygons'] = [np.array(i) for i in polygons]
info['texts'] = texts
info['pos'] = pos
d.append(info)
print(f'{input_path} loaded, imgs={len(d)}')
if count > 0:
print(f"Found {count} image's caption contain placeholder: {PLACE_HOLDER}, change to ' '...")
return d
def draw_pos(polygon, prob=1.0):
img = np.zeros((512, 512, 1))
if random.random() < prob:
pts = polygon.reshape((-1, 1, 2))
cv2.fillPoly(img, [pts], color=255)
return img/255.
def get_item(data_list, item):
item_dict = {}
cur_item = data_list[item]
item_dict['img_name'] = cur_item['img_name']
item_dict['caption'] = cur_item['caption']
item_dict['glyphs'] = []
item_dict['gly_line'] = []
item_dict['positions'] = []
item_dict['texts'] = []
texts = cur_item.get('texts', [])
if len(texts) > 0:
sel_idxs = [i for i in range(len(texts))]
if len(texts) > max_lines:
sel_idxs = sel_idxs[:max_lines]
pos_idxs = [cur_item['pos'][i] for i in sel_idxs]
item_dict['caption'] = get_caption_pos(item_dict['caption'], pos_idxs, 0.0, PLACE_HOLDER)
item_dict['polygons'] = [cur_item['polygons'][i] for i in sel_idxs]
item_dict['texts'] = [cur_item['texts'][i][:max_chars] for i in sel_idxs]
# glyphs
for idx, text in enumerate(item_dict['texts']):
gly_line = draw_glyph(font, text)
glyphs = draw_glyph2(font, text, item_dict['polygons'][idx], scale=2)
item_dict['glyphs'] += [glyphs]
item_dict['gly_line'] += [gly_line]
# mask_pos
for polygon in item_dict['polygons']:
item_dict['positions'] += [draw_pos(polygon, 1.0)]
fill_caption = False
if fill_caption: # if using embedding_manager, DO NOT fill caption!
for i in range(len(item_dict['texts'])):
r_txt = item_dict['texts'][i]
item_dict['caption'] = item_dict['caption'].replace(PLACE_HOLDER, f'"{r_txt}"', 1)
# padding
n_lines = min(len(texts), max_lines)
item_dict['n_lines'] = n_lines
n_pad = max_lines - n_lines
if n_pad > 0:
item_dict['glyphs'] += [np.zeros((512*2, 512*2, 1))] * n_pad
item_dict['gly_line'] += [np.zeros((80, 512, 1))] * n_pad
item_dict['positions'] += [np.zeros((512, 512, 1))] * n_pad
return item_dict
def process(model, ddim_sampler, item_dict, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, strength, scale, seed, eta):
with torch.no_grad():
prompt = item_dict['caption']
n_lines = item_dict['n_lines']
pos_imgs = item_dict['positions']
glyphs = item_dict['glyphs']
gly_line = item_dict['gly_line']
hint = np.sum(pos_imgs, axis=0).clip(0, 1)
H, W = (512, 512)
if seed == -1:
seed = random.randint(0, 65535)
seed_everything(seed)
if save_memory:
model.low_vram_shift(is_diffusing=False)
info = {}
info['glyphs'] = []
info['gly_line'] = []
info['positions'] = []
info['n_lines'] = [n_lines]*num_samples
for i in range(n_lines):
glyph = glyphs[i]
pos = pos_imgs[i]
gline = gly_line[i]
info['glyphs'] += [arr2tensor(glyph, num_samples)]
info['gly_line'] += [arr2tensor(gline, num_samples)]
info['positions'] += [arr2tensor(pos, num_samples)]
# get masked_x
ref_img = np.ones((H, W, 3)) * 127.5
masked_img = ((ref_img.astype(np.float32) / 127.5) - 1.0)*(1-hint)
masked_img = np.transpose(masked_img, (2, 0, 1))
masked_img = torch.from_numpy(masked_img.copy()).float().cuda()
encoder_posterior = model.encode_first_stage(masked_img[None, ...])
masked_x = model.get_first_stage_encoding(encoder_posterior).detach()
info['masked_x'] = torch.cat([masked_x for _ in range(num_samples)], dim=0)
hint = arr2tensor(hint, num_samples)
cond = model.get_learned_conditioning(dict(c_concat=[hint], c_crossattn=[[prompt + ', ' + a_prompt] * num_samples], text_info=info))
un_cond = model.get_learned_conditioning(dict(c_concat=[hint], c_crossattn=[[n_prompt] * num_samples], text_info=info))
shape = (4, H // 8, W // 8)
if save_memory:
model.low_vram_shift(is_diffusing=True)
model.control_scales = ([strength] * 13)
tic = time.time()
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
shape, cond, verbose=False, eta=eta,
unconditional_guidance_scale=scale,
unconditional_conditioning=un_cond)
cost = (time.time() - tic)*1000.
if save_memory:
model.low_vram_shift(is_diffusing=False)
x_samples = model.decode_first_stage(samples)
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
results = [x_samples[i] for i in range(num_samples)]
results += [cost]
return results
if __name__ == '__main__':
args = parse_args()
total = 21
times = []
data_list = load_data(args.input_json)
model = create_model(config_yaml).cuda()
model.load_state_dict(load_state_dict(args.ckpt_path, location='cuda'), strict=True)
ddim_sampler = DDIMSampler(model)
for i in tqdm(range(len(data_list)), desc='generator'):
item_dict = get_item(data_list, i)
img_name = item_dict['img_name'].split('.')[0] + '_3.jpg'
if os.path.exists(os.path.join(args.output_dir, img_name)):
continue
results = process(model, ddim_sampler, item_dict, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, strength, scale, seed, eta)
times += [results.pop()]
if i == total:
print(times)
times = times[1:]
print(f'{np.mean(times)}')
for idx, img in enumerate(results):
img_name = item_dict['img_name'].split('.')[0]+f'_{idx}' + '.jpg'
cv2.imwrite(os.path.join(args.output_dir, img_name), img[..., ::-1])
import os
import shutil
import copy
import argparse
import pathlib
import json
def load(file_path: str):
file_path = pathlib.Path(file_path)
func_dict = {'.json': load_json}
assert file_path.suffix in func_dict
return func_dict[file_path.suffix](file_path)
def load_json(file_path: str):
with open(file_path, 'r', encoding='utf8') as f:
content = json.load(f)
return content
def save(data, file_path):
file_path = pathlib.Path(file_path)
func_dict = {'.json': save_json}
assert file_path.suffix in func_dict
return func_dict[file_path.suffix](data, file_path)
def save_json(data, file_path):
with open(file_path, 'w', encoding='utf-8') as json_file:
json.dump(data, json_file, ensure_ascii=False, indent=4)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_path",
type=str,
default='/home/yuxiang.tyx/projects/AnyText/models/control_sd15_canny.pth',
help='path of model'
)
parser.add_argument(
"--gpus",
type=str,
default='0,1,2,3,4,5,6,7',
help='gpus for inference'
)
parser.add_argument(
"--output_dir",
type=str,
default='./controlnet_laion_generated/',
help="output path"
)
parser.add_argument(
"--glyph_dir",
type=str,
default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/glyph_laion',
help="path of glyph images from anytext evaluation dataset"
)
parser.add_argument(
"--json_path",
type=str,
default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json',
help="json path for evaluation dataset"
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
output_dir = args.output_dir
tmp_dir = './tmp_dir'
exec_path = './controlnet_singleGPU.py'
continue_gen = True # if True, do not clear output_dir; only generate the remaining images.
if os.path.exists(tmp_dir):
shutil.rmtree(tmp_dir)
os.makedirs(tmp_dir)
if not continue_gen:
if os.path.exists(output_dir):
shutil.rmtree(output_dir)
os.makedirs(output_dir)
else:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
os.system('sleep 1')
gpu_ids = [int(i) for i in args.gpus.split(',')]
nproc = len(gpu_ids)
all_lines = load(args.json_path)
split_file = []
length = len(all_lines['data_list']) // nproc
cmds = []
for i in range(nproc):
start, end = i*length, (i+1)*length
if i == nproc - 1:
end = len(all_lines['data_list'])
temp_lines = copy.deepcopy(all_lines)
temp_lines['data_list'] = temp_lines['data_list'][start:end]
tmp_file = os.path.join(tmp_dir, f'tmp_list_{i}.json')
save(temp_lines, tmp_file)
os.system('sleep 1')
cmds += [f'export CUDA_VISIBLE_DEVICES={gpu_ids[i]} && python {exec_path} --json_path {tmp_file} --output_dir {output_dir} --model_path {args.model_path} --glyph_dir {args.glyph_dir} && echo proc-{i} done!']
cmds = ' & '.join(cmds)
os.system(cmds)
print('Done.')
os.system('sleep 2')
shutil.rmtree(tmp_dir)
'''
command to kill the task after running:
$ps -ef | grep singleGPU | awk '{ print $2 }' | xargs kill -9 && ps -ef | grep multiproce | awk '{ print $2 }' | xargs kill -9
'''
'''
Part of the implementation is borrowed and modified from ControlNet, publicly available at https://github.com/lllyasviel/ControlNet/blob/main/gradio_canny2image.py
'''
from share import *
import config
import cv2
import einops
import numpy as np
import torch
import random
from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from cldm.model import create_model, load_state_dict
from ldm.models.diffusion.ddim import DDIMSampler
from PIL import Image
import os
import json
from tqdm import tqdm
import argparse
def parse_args():
parser = argparse.ArgumentParser()
# specify the inference settings
parser.add_argument(
"--model_path",
type=str,
default='/home/yuxiang.tyx/projects/AnyText/models/control_sd15_canny.pth',
help='path of model'
)
parser.add_argument(
"--num_samples",
type=int,
default=4,
help="how many samples to produce for each given prompt. A.k.a batch size",
)
parser.add_argument(
"--a_prompt",
type=str,
default='best quality, extremely detailed',
help="additional prompt"
)
parser.add_argument(
"--n_prompt",
type=str,
default='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, watermark',
help="negative prompt"
)
parser.add_argument(
"--image_resolution",
type=int,
default=512,
help="image resolution",
)
parser.add_argument(
"--strength",
type=float,
default=1,
help="control strength",
)
parser.add_argument(
"--scale",
type=float,
default=9.0,
help="classifier-free guidance scale",
)
parser.add_argument(
"--ddim_steps",
type=int,
default=20,
help="ddim steps",
)
parser.add_argument(
"--seed",
type=int,
default=100,
help="seed",
)
parser.add_argument(
"--guess_mode",
action="store_true",
help="whether use guess mode",
)
parser.add_argument(
"--eta",
type=float,
default=0,
help="eta",
)
parser.add_argument(
"--low_threshold",
type=int,
default=100,
help="low threshold",
)
parser.add_argument(
"--high_threshold",
type=int,
default=200,
help="high threshold",
)
parser.add_argument(
"--output_dir",
type=str,
default='./controlnet_laion_generated/',
help="output path"
)
parser.add_argument(
"--glyph_dir",
type=str,
default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/glyph_laion',
help="path of glyph images from anytext evaluation dataset"
)
parser.add_argument(
"--json_path",
type=str,
default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json',
help="json path for evaluation dataset"
)
args = parser.parse_args()
return args
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold, model, ddim_sampler):
with torch.no_grad():
img = resize_image(HWC3(input_image), image_resolution)
H, W, C = img.shape
detected_map = apply_canny(img, low_threshold, high_threshold)
detected_map = HWC3(detected_map)
control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
control = torch.stack([control for _ in range(num_samples)], dim=0)
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
if seed == -1:
seed = random.randint(0, 65535)
seed_everything(seed)
if config.save_memory:
model.low_vram_shift(is_diffusing=False)
cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
shape = (4, H // 8, W // 8)
if config.save_memory:
model.low_vram_shift(is_diffusing=True)
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
shape, cond, mask=None,
x0=None, verbose=False, eta=eta,
unconditional_guidance_scale=scale,
unconditional_conditioning=un_cond)
if config.save_memory:
model.low_vram_shift(is_diffusing=False)
x_samples = model.decode_first_stage(samples)
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
results = [x_samples[i] for i in range(num_samples)]
return results
def load_json(file_path: str):
with open(file_path, 'r', encoding='utf8') as f:
content = json.load(f)
return content
def load_data(input_path):
content = load_json(input_path)
d = []
count = 0
for gt in content['data_list']:
info = {}
info['img_name'] = gt['img_name']
info['caption'] = gt['caption']
if PLACE_HOLDER in info['caption']:
count += 1
info['caption'] = info['caption'].replace(PLACE_HOLDER, " ")
if 'annotations' in gt:
polygons = []
texts = []
pos = []
for annotation in gt['annotations']:
if len(annotation['polygon']) == 0:
continue
if annotation['valid'] is False:
continue
polygons.append(annotation['polygon'])
texts.append(annotation['text'])
pos.append(annotation['pos'])
info['polygons'] = [np.array(i) for i in polygons]
info['texts'] = texts
info['pos'] = pos
d.append(info)
print(f'{input_path} loaded, imgs={len(d)}')
if count > 0:
print(f"Found {count} image's caption contain placeholder: {PLACE_HOLDER}, change to ' '...")
return d
def get_item(data_list, item):
item_dict = {}
cur_item = data_list[item]
item_dict['img_name'] = cur_item['img_name']
item_dict['caption'] = cur_item['caption']
return item_dict
if __name__ == "__main__":
args = parse_args()
apply_canny = CannyDetector()
model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict(args.model_path, location='cuda'), strict=False)
model = model.cuda()
ddim_sampler = DDIMSampler(model)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
PLACE_HOLDER = '*'
data_list = load_data(args.json_path)
for i in tqdm(range(len(data_list)), desc='generator'):
item_dict = get_item(data_list, i)
p = item_dict['img_name']
img_name = item_dict['img_name'].split('.')[0] + '_3.jpg'
if os.path.exists(os.path.join(args.output_dir, img_name)):
continue
input_image_path = os.path.join(args.glyph_dir, p)
prompt = item_dict['caption']
img = Image.open(input_image_path)
input_image = np.array(img)
results = process(input_image, prompt, args.a_prompt, args.n_prompt, args.num_samples, args.image_resolution, args.ddim_steps, args.scale, args.seed, args.eta, args.low_threshold, args.high_threshold, model, ddim_sampler)
for idx, img in enumerate(results):
img_name = item_dict['img_name'].split('.')[0]+f'_{idx}' + '.jpg'
cv2.imwrite(os.path.join(args.output_dir, img_name), img[..., ::-1])
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import cv2
from cldm.recognizer import TextRecognizer, crop_image
from easydict import EasyDict as edict
from anytext_singleGPU import load_data, get_item
from tqdm import tqdm
import os
import torch
import Levenshtein
import numpy as np
import math
import argparse
PRINT_DEBUG = False
num_samples = 4
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--img_dir",
type=str,
default='/home/yuxiang.tyx/projects/ControlNet/controlnet_wukong_generated',
help='path of generated images for eval'
)
parser.add_argument(
"--input_json",
type=str,
default='/data/vdb/yuxiang.tyx/AIGC/data/wukong_word/test1k.json',
help='json path for evaluation dataset'
)
args = parser.parse_args()
return args
args = parse_args()
img_dir = args.img_dir
input_json = args.input_json
if 'wukong' in input_json:
model_lang = 'ch'
rec_char_dict_path = os.path.join('./ocr_weights', 'ppocr_keys_v1.txt')
elif 'laion' in input_json:
rec_char_dict_path = os.path.join('./ocr_weights', 'en_dict.txt')
def get_ld(ls1, ls2):
edit_dist = Levenshtein.distance(ls1, ls2)
return 1 - edit_dist/(max(len(ls1), len(ls2)) + 1e-5)
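# Quick example (illustrative, runs only when PRINT_DEBUG is enabled): get_ld returns a
# normalized similarity in [0, 1]; identical sequences give ~1.0.
if PRINT_DEBUG:
    print(f"get_ld sanity check: {get_ld('hello', 'hallo'):.2f}")  # one edit over five chars -> ~0.80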
def pre_process(img_list, shape):
numpy_list = []
img_num = len(img_list)
assert img_num > 0
for idx in range(0, img_num):
# rotate
img = img_list[idx]
h, w = img.shape[1:]
if h > w * 1.2:
img = torch.transpose(img, 1, 2).flip(dims=[1])
img_list[idx] = img
h, w = img.shape[1:]
# resize
imgC, imgH, imgW = (int(i) for i in shape.strip().split(','))
assert imgC == img.shape[0]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = torch.nn.functional.interpolate(
img.unsqueeze(0),
size=(imgH, resized_w),
mode='bilinear',
align_corners=True,
)
# padding
padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32)
padding_im[:, :, 0:resized_w] = resized_image[0]
numpy_list += [padding_im.permute(1, 2, 0).cpu().numpy()] # HWC ,numpy
return numpy_list
def main():
predictor = pipeline(Tasks.ocr_recognition, model='damo/cv_convnextTiny_ocr-recognition-general_damo')
rec_image_shape = "3, 48, 320"
args = edict()
args.rec_image_shape = rec_image_shape
args.rec_char_dict_path = rec_char_dict_path
args.rec_batch_num = 1
args.use_fp16 = False
text_recognizer = TextRecognizer(args, None)
data_list = load_data(input_json)
sen_acc = []
edit_dist = []
for i in tqdm(range(len(data_list)), desc='evaluate'):
item_dict = get_item(data_list, i)
img_name = item_dict['img_name'].split('.')[0]
n_lines = item_dict['n_lines']
for j in range(num_samples):
img_path = os.path.join(img_dir, img_name+f'_{j}.jpg')
img = cv2.imread(img_path)
if PRINT_DEBUG:
cv2.imwrite(f'{i}_{j}.jpg', img)
img = torch.from_numpy(img)
img = img.permute(2, 0, 1).float() # HWC-->CHW
gt_texts = []
pred_texts = []
for k in range(n_lines): # line
gt_texts += [item_dict['texts'][k]]
np_pos = (item_dict['positions'][k]*255.).astype(np.uint8) # 0-1, hwc
pred_text = crop_image(img, np_pos)
pred_texts += [pred_text]
if n_lines > 0:
pred_texts = pre_process(pred_texts, rec_image_shape)
preds_all = []
for idx, pt in enumerate(pred_texts):
if PRINT_DEBUG:
cv2.imwrite(f'{i}_{j}_{idx}.jpg', pt)
rst = predictor(pt)
preds_all += [rst['text'][0]]
for k in range(len(preds_all)):
pred_text = preds_all[k]
gt_order = [text_recognizer.char2id.get(m, len(text_recognizer.chars)-1) for m in gt_texts[k]]
pred_order = [text_recognizer.char2id.get(m, len(text_recognizer.chars)-1) for m in pred_text]
if pred_text == gt_texts[k]:
sen_acc += [1]
else:
sen_acc += [0]
edit_dist += [get_ld(pred_order, gt_order)]
if PRINT_DEBUG:
print(f'pred/gt="{pred_text}"/"{gt_texts[k]}", ed={edit_dist[-1]:.4f}')
print(f'Done, lines={len(sen_acc)}, sen_acc={np.array(sen_acc).mean():.4f}, edit_dist={np.array(edit_dist).mean():.4f}')
if __name__ == "__main__":
main()
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0
python -m pytorch_fid \
/data/vdb/yuxiang.tyx/AIGC/data/wukong_word/fid/wukong-40k \
/data/vdc/yuxiang.tyx/AIGC/anytext_eval_imgs/controlnet_wukong_generated
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0
python eval/eval_dgocr.py \
--img_dir /data/vdc/yuxiang.tyx/AIGC/anytext_eval_imgs/controlnet_wukong_generated \
--input_json /data/vdb/yuxiang.tyx/AIGC/data/wukong_word/test1k.json
#!/bin/bash
python eval/render_glyph_imgs.py \
--json_path /data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json \
--output_dir /data/vdb/yuxiang.tyx/AIGC/data/laion_word/glyph_laion
#!/bin/bash
python eval/anytext_multiGPUs.py \
--model_path models/anytext_v1.1.ckpt \
--json_path /data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json \
--output_dir ./anytext_laion_generated \
--gpus 0,1,2,3,4,5,6,7
#!/bin/bash
python controlnet_multiGPUs.py \
--model_path /home/yuxiang.tyx/projects/AnyText/models/control_sd15_canny.pth \
--json_path /data/vdb/yuxiang.tyx/AIGC/data/wukong_word/test1k.json \
--glyph_dir /data/vdb/yuxiang.tyx/AIGC/data/wukong_word/glyph_wukong \
--output_dir ./controlnet_wukong_generated \
--gpus 0,1,2,3,4,5,6,7
#!/bin/bash
python glyphcontrol_multiGPUs.py \
--model_path checkpoints/laion10M_epoch_6_model_ema_only.ckpt \
--json_path /data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json \
--glyph_dir /data/vdb/yuxiang.tyx/AIGC/data/laion_word/glyph_laion \
--output_dir ./glyphcontrol_laion_generated \
--gpus 0,1,2,3,4,5,6,7
#!/bin/bash
python textdiffuser_multiGPUs.py \
--model_path textdiffuser-ckpt/diffusion_backbone \
--json_path /data/vdb/yuxiang.tyx/AIGC/data/wukong_word/test1k.json \
--glyph_dir /data/vdb/yuxiang.tyx/AIGC/data/wukong_word/glyph_wukong \
--output_dir ./textdiffuser_wukong_generated \
--gpus 0,1,2,3,4,5,6,7
import os
import shutil
import copy
import argparse
import pathlib
import json
def load(file_path: str):
file_path = pathlib.Path(file_path)
func_dict = {'.json': load_json}
assert file_path.suffix in func_dict
return func_dict[file_path.suffix](file_path)
def load_json(file_path: str):
with open(file_path, 'r', encoding='utf8') as f:
content = json.load(f)
return content
def save(data, file_path):
file_path = pathlib.Path(file_path)
func_dict = {'.json': save_json}
assert file_path.suffix in func_dict
return func_dict[file_path.suffix](data, file_path)
def save_json(data, file_path):
with open(file_path, 'w', encoding='utf-8') as json_file:
json.dump(data, json_file, ensure_ascii=False, indent=4)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_path",
type=str,
default='checkpoints/laion10M_epoch_6_model_ema_only.ckpt',
help='path to checkpoint of model'
)
parser.add_argument(
"--gpus",
type=str,
default='0,1,2,3,4,5,6,7',
help='gpus for inference'
)
parser.add_argument(
"--output_dir",
type=str,
default='./glyphcontrol_laion_generated/',
help="output path"
)
parser.add_argument(
"--glyph_dir",
type=str,
default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/glyph_laion',
help="path of glyph images from anytext evaluation dataset"
)
parser.add_argument(
"--json_path",
type=str,
default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json',
help="json path for evaluation dataset"
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
output_dir = args.output_dir
tmp_dir = './tmp_dir'
exec_path = './glyphcontrol_singleGPU.py'
continue_gen = True # if True, do not clear output_dir; only generate the remaining images.
if os.path.exists(tmp_dir):
shutil.rmtree(tmp_dir)
os.makedirs(tmp_dir)
if not continue_gen:
if os.path.exists(output_dir):
shutil.rmtree(output_dir)
os.makedirs(output_dir)
else:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
os.system('sleep 1')
gpu_ids = [int(i) for i in args.gpus.split(',')]
nproc = len(gpu_ids)
all_lines = load(args.json_path)
split_file = []
length = len(all_lines['data_list']) // nproc
cmds = []
for i in range(nproc):
start, end = i*length, (i+1)*length
if i == nproc - 1:
end = len(all_lines['data_list'])
temp_lines = copy.deepcopy(all_lines)
temp_lines['data_list'] = temp_lines['data_list'][start:end]
tmp_file = os.path.join(tmp_dir, f'tmp_list_{i}.json')
save(temp_lines, tmp_file)
os.system('sleep 1')
cmds += [f'export CUDA_VISIBLE_DEVICES={gpu_ids[i]} && python {exec_path} --json_path {tmp_file} --output_dir {output_dir} --model_path {args.model_path} --glyph_dir {args.glyph_dir} && echo proc-{i} done!']
cmds = ' & '.join(cmds)
os.system(cmds)
print('Done.')
os.system('sleep 2')
shutil.rmtree(tmp_dir)
'''
command to kill the task after running:
$ps -ef | grep singleGPU | awk '{ print $2 }' | xargs kill -9 && ps -ef | grep multiproce | awk '{ print $2 }' | xargs kill -9
'''
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from tqdm import tqdm
import shutil
import numpy as np
import cv2
from PIL import Image, ImageFont
from torch.utils.data import DataLoader
from dataset_util import show_bbox_on_image
import argparse
from t3_dataset import T3DataSet
max_lines = 20
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--json_path",
type=str,
default='/data/vdb/yuxiang.tyx/AIGC/data/wukong_word/test1k.json',
help="json path for evaluation dataset",
)
parser.add_argument(
"--output_dir",
type=str,
default='/data/vdb/yuxiang.tyx/AIGC/data/wukong_word/glyph_wukong',
help="output path, clear the folder if exist",
)
parser.add_argument(
"--img_count",
type=int,
default=1000,
help="image count",
)
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
if os.path.exists(args.output_dir):
shutil.rmtree(args.output_dir)
os.makedirs(args.output_dir)
dataset = T3DataSet(args.json_path, for_show=True, max_lines=max_lines, glyph_scale=2, mask_img_prob=1.0, caption_pos_prob=0.0)
train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
pbar = tqdm(total=args.img_count)
for i, data in enumerate(train_loader):
if i == args.img_count:
break
all_glyphs = []
for k, glyphs in enumerate(data['glyphs']):
all_glyphs += [glyphs[0].numpy().astype(np.int32)*255]
glyph_img = cv2.resize(255.0-np.sum(all_glyphs, axis=0), (512, 512))
cv2.imwrite(os.path.join(args.output_dir, data['img_name'][0]), glyph_img)
pbar.update(1)
pbar.close()
import os
import shutil
import copy
import argparse
import pathlib
import json
def load(file_path: str):
file_path = pathlib.Path(file_path)
func_dict = {'.json': load_json}
assert file_path.suffix in func_dict
return func_dict[file_path.suffix](file_path)
def load_json(file_path: str):
with open(file_path, 'r', encoding='utf8') as f:
content = json.load(f)
return content
def save(data, file_path):
file_path = pathlib.Path(file_path)
func_dict = {'.json': save_json}
assert file_path.suffix in func_dict
return func_dict[file_path.suffix](data, file_path)
def save_json(data, file_path):
with open(file_path, 'w', encoding='utf-8') as json_file:
json.dump(data, json_file, ensure_ascii=False, indent=4)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_path",
type=str,
default='textdiffuser-ckpt/diffusion_backbone',
help='path to model'
)
parser.add_argument(
"--gpus",
type=str,
default='0,1,2,3,4,5,6,7',
help='gpus for inference'
)
parser.add_argument(
"--output_dir",
type=str,
default='./textdiffuser_laion_generated/',
help="output path"
)
parser.add_argument(
"--glyph_dir",
type=str,
default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/glyph_laion',
help="path of glyph images from anytext evaluation dataset"
)
parser.add_argument(
"--json_path",
type=str,
default='/data/vdb/yuxiang.tyx/AIGC/data/laion_word/test1k.json',
help="json path for evaluation dataset"
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
output_dir = args.output_dir
tmp_dir = './tmp_dir'
exec_path = './textdiffuser_singleGPU.py'
continue_gen = True # if True, do not clear output_dir; only generate the remaining images.
if os.path.exists(tmp_dir):
shutil.rmtree(tmp_dir)
os.makedirs(tmp_dir)
if not continue_gen:
if os.path.exists(output_dir):
shutil.rmtree(output_dir)
os.makedirs(output_dir)
else:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
os.system('sleep 1')
gpu_ids = [int(i) for i in args.gpus.split(',')]
nproc = len(gpu_ids)
all_lines = load(args.json_path)
split_file = []
length = len(all_lines['data_list']) // nproc
cmds = []
for i in range(nproc):
start, end = i*length, (i+1)*length
if i == nproc - 1:
end = len(all_lines['data_list'])
temp_lines = copy.deepcopy(all_lines)
temp_lines['data_list'] = temp_lines['data_list'][start:end]
tmp_file = os.path.join(tmp_dir, f'tmp_list_{i}.json')
save(temp_lines, tmp_file)
os.system('sleep 1')
cmds += [f'export CUDA_VISIBLE_DEVICES={gpu_ids[i]} && python {exec_path} --json_path {tmp_file} --output_dir {output_dir} --model_path {args.model_path} --glyph_dir {args.glyph_dir} && echo proc-{i} done!']
cmds = ' & '.join(cmds)
os.system(cmds)
print('Done.')
os.system('sleep 2')
shutil.rmtree(tmp_dir)
'''
command to kill the task after running:
$ps -ef | grep singleGPU | awk '{ print $2 }' | xargs kill -9 && ps -ef | grep multiproce | awk '{ print $2 }' | xargs kill -9
'''
from modelscope.pipelines import pipeline
from util import save_images
pipe = pipeline('my-anytext-task', model='damo/cv_anytext_text_generation_editing', model_revision='v1.1.3')
img_save_folder = "SaveImages"
params = {
"show_debug": True,
"image_count": 2,
"ddim_steps": 20,
}
# 1. text generation
mode = 'text-generation'
input_data = {
"prompt": 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream',
"seed": 66273235,
"draw_pos": 'example_images/gen9.png'
}
results, rtn_code, rtn_warning, debug_info = pipe(input_data, mode=mode, **params)
if rtn_code >= 0:
save_images(results, img_save_folder)
print(f'Done, result images are saved in: {img_save_folder}')
if rtn_warning:
print(rtn_warning)
# 2. text editing
mode = 'text-editing'
input_data = {
"prompt": 'A cake with colorful characters that reads "EVERYDAY"',
"seed": 8943410,
"draw_pos": 'example_images/edit7.png',
"ori_image": 'example_images/ref7.jpg'
}
results, rtn_code, rtn_warning, debug_info = pipe(input_data, mode=mode, **params)
if rtn_code >= 0:
save_images(results, img_save_folder)
print(f'Done, result images are saved in: {img_save_folder}')
if rtn_warning:
print(rtn_warning)
import torch
from ldm.modules.midas.api import load_midas_transform
class AddMiDaS(object):
def __init__(self, model_type):
super().__init__()
self.transform = load_midas_transform(model_type)
def pt2np(self, x):
x = ((x + 1.0) * .5).detach().cpu().numpy()
return x
def np2pt(self, x):
x = torch.from_numpy(x) * 2 - 1.
return x
def __call__(self, sample):
# sample['jpg'] is tensor hwc in [-1, 1] at this point
x = self.pt2np(sample['jpg'])
x = self.transform({"image": x})["image"]
sample['midas_in'] = x
return sample
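# Usage sketch (assumption: 'dpt_hybrid' is one of the MiDaS model types accepted by
# ldm.modules.midas.api.load_midas_transform; the random image below is purely illustrative):
if __name__ == "__main__":
    midas = AddMiDaS(model_type="dpt_hybrid")
    sample = {'jpg': torch.rand(384, 384, 3) * 2 - 1}   # HWC tensor in [-1, 1]
    sample = midas(sample)
    print(sample['midas_in'].shape)                     # CHW array ready for the depth model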
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.util import instantiate_from_config
from ldm.modules.ema import LitEma
class AutoencoderKL(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False
):
super().__init__()
self.learn_logvar = learn_logvar
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.use_ema = ema_decay is not None
if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, postfix=""):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
if self.learn_logvar:
print(f"{self.__class__.__name__}: Learning logvar")
ae_params_list.append(self.loss.logvar)
opt_ae = torch.optim.Adam(ae_params_list,
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
if log_ema or self.use_ema:
with self.ema_scope():
xrec_ema, posterior_ema = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec_ema.shape[1] > 3
xrec_ema = self.to_rgb(xrec_ema)
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
log["reconstructions_ema"] = xrec_ema
log["inputs"] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x
from .sampler import DPMSolverSampler
import torch
import numpy as np
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
return x[(...,) + (None,) * dims_to_append]
def norm_thresholding(x0, value):
s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
return x0 * (value / s)
def spatial_norm_thresholding(x0, value):
# b c h w
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
return x0 * (value / s)
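# Minimal usage sketch (illustrative): broadcast a per-sample scalar against a 4-D latent
# by appending trailing singleton dimensions with append_dims.
if __name__ == "__main__":
    x0 = torch.randn(2, 4, 8, 8)
    scale = torch.tensor([0.5, 2.0])                     # one value per batch element
    print((x0 * append_dims(scale, x0.ndim)).shape)      # torch.Size([2, 4, 8, 8])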
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light