from torch.utils.data import dataset
from tqdm import tqdm
import network
import utils
import os
import random
import argparse
import numpy as np

from torch.utils import data
from datasets import VOCSegmentation, Cityscapes, cityscapes
from torchvision import transforms as T
#from metrics import StreamSegMetrics

import torch
import torch.nn as nn

from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
from glob import glob

# Number of output classes of the segmentation head (Cityscapes label set).
num_classes = 19
# Converts a predicted class-index mask into a colour image; used by the
# (commented-out) inference snippet below before `Image.fromarray`.
decode_fn = Cityscapes.decode_target

CHECKPOINT_PATH = (
    "third_party/DeepLabV3Plus-Pytorch/"
    "best_deeplabv3plus_mobilenet_cityscapes_os16.pth"
)

# Prefer the GPU, but fall back to CPU so the script still loads on machines
# without CUDA instead of raising inside `.cuda()`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the bare network. Reuse `num_classes` rather than repeating the
# literal so the head size and the colour decoder cannot drift apart.
model = network.modeling.__dict__["deeplabv3plus_mobilenet"](
    num_classes=num_classes, output_stride=16
)

# Deserialize onto CPU first so loading works regardless of the device the
# checkpoint was saved from; the model is moved to `device` afterwards.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu"))
# Load weights into the *unwrapped* model: nn.DataParallel prefixes parameter
# names with "module.", so wrapping first would change the expected keys.
# (Assumes the checkpoint was saved from the unwrapped module — TODO confirm.)
model.load_state_dict(checkpoint["model_state"])

# DataParallel falls back to calling the wrapped module directly when no
# CUDA device is available, so this is safe on CPU as well.
model = nn.DataParallel(model)
model = model.to(device)

# Switch to inference mode (affects layers such as BatchNorm and Dropout).
model = model.eval()
#with torch.no_grad():
#    #ext = os.path.basename(img_path).split('.')[-1]
#    #img_name = os.path.basename(img_path)[:-len(ext)-1]
#    import matplotlib.pyplot as plt; 
#    img = plt.imread("img/ref/rgb_gt.png")[...,:3];#Image.open(img_path).convert('RGB')
#    #from pdb import set_trace as pdb_;pdb_() 
#    img = torch.from_numpy(np.array(img)).permute(2,0,1)[None]#transform(img).unsqueeze(0) # To tensor of NCHW
#    img = img.cuda()
#    
#    pred = model(img).max(1)[1].cpu().numpy()[0] # HW
#    from pdb import set_trace as pdb_;pdb_() 
#    colorized_preds = decode_fn(pred).astype('uint8')
#    colorized_preds = Image.fromarray(colorized_preds)
#    colorized_preds.save(os.path.join("test_results", 'rgb_gt.png'))
