import json, os, string, random, time, pickle, gc
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models

from PIL import Image
import matplotlib.pyplot as plt

from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
import pdb, string, sys
from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet
from nltk import pos_tag

# Script-wide settings and NLTK text helpers.
# NOTE(review): lem and swords are not referenced in the visible code; they
# are presumably kept for caption preprocessing elsewhere — confirm.
lem = WordNetLemmatizer()                     # WordNet-based lemmatizer instance
swords = set(stopwords.words('english'))      # English stop words as a set for O(1) lookup
workdir = './'                                # pickle outputs are written under this directory

# Load the COCO 2014 caption and instance annotation files from the current
# directory and print a short summary of their structure.
def _load_json(path):
    """Parse and return the JSON document at *path*, closing the file afterwards."""
    with open(path) as fh:
        return json.load(fh)


cocoTrainCaption = _load_json("./captions_train2014.json")
cocoValCaption = _load_json("./captions_val2014.json")
cocoTrainInstances = _load_json("./instances_train2014.json")
cocoValInstances = _load_json("./instances_val2014.json")

print("size of caption train data: ", len(cocoTrainCaption['images']))
print("size of instance train data: ", len(cocoTrainInstances['images']))
print("size of caption val data: ", len(cocoValCaption['images']))
# Must read the *instance* val file; the caption file is reported above.
print("size of instance val data: ", len(cocoValInstances['images']))
print("keys in caption data: ", cocoValCaption.keys())
print("\n")
print("keys in instance data: ", cocoValInstances.keys())
print("\n")
print("first image in caption data: ", cocoValCaption['images'][0])
print("\n")
print("first annotation in caption data: ", cocoValCaption['annotations'][0])
print("\n")
print("first image in instance data: ", cocoValInstances['images'][0])
print("\n")
print("first annotation in instance data: ", cocoValInstances['annotations'][0])


# Bidirectional map between COCO category names and category ids, built from
# the val instance file and saved to objectInfo.p.
# NOTE: category ids are not list indices (id != index).
object2id = {}  # category name -> category id
id2object = {}  # category id -> category name
for category in cocoValInstances['categories']:
    id2object[category['id']] = category['name']
    object2id[category['name']] = category['id']

# Close the output file deterministically instead of leaking the handle.
with open(os.path.join(workdir, 'objectInfo.p'), 'wb') as fh:
    pickle.dump({'object2id': object2id, 'id2object': id2object}, fh)

# Pool the official train and val splits; the script re-partitions them below.
Images = cocoTrainCaption['images'] + cocoValCaption['images']
captionAnnotations = cocoTrainCaption['annotations'] + cocoValCaption['annotations']
instanceAnnotations = cocoTrainInstances['annotations'] + cocoValInstances['annotations']

from collections import defaultdict
# image id -> list of category ids, one entry per instance annotation, so a
# category appears once for every annotated instance of it in the image.
image2objects = defaultdict(list)
for annotation in instanceAnnotations:
    image2objects[annotation['image_id']].append(annotation['category_id'])
# Debug peek at one image's categories. The print(...) call form works on
# Python 2 and 3, unlike a bare print statement. .get() avoids inserting an
# empty entry into the defaultdict when this id is absent.
print(image2objects.get(262145, []))


# Decorate every image record with its object categories and an absolute
# path to the JPEG on local disk.
for image in Images:
    file_name = image['file_name']
    # The split directory is chosen from the substring 'val' in the file name.
    if 'val' not in file_name:
        image_root = "/localtmp/data/coco/train2014"
    else:
        image_root = "/localtmp/data/coco/val2014/"
    # NOTE: the key 'objetcs' is misspelled. It is kept as-is because it is
    # written into the pickled splits and downstream readers may depend on it.
    # image2objects is a defaultdict, so images with no instances get [].
    image['objetcs'] = image2objects[image['id']]
    image['file_path'] = os.path.join(image_root, file_name)
print(Images[0])

# image id -> every caption string annotated for that image, in file order.
imageId2annotation = defaultdict(list)
for annotation in captionAnnotations:
    captions_for_image = imageId2annotation[annotation['image_id']]
    captions_for_image.append(annotation['caption'])


from copy import deepcopy
trainData = list()  # one record per (image, caption) pair
valData = list()    # one record per image, labelled with its first caption
testData = list()   # one record per image, labelled with its first caption
from random import shuffle
# Shuffle with a fixed seed so re-running the script reproduces the same
# val/test/train partition. A private Random instance leaves the global RNG
# state untouched.
SPLIT_SEED = 1234
random.Random(SPLIT_SEED).shuffle(Images)
def add_caption(dataList, image, captionLookup=None):
    """Append one record to *dataList* for each caption of *image*.

    Each record is a deep copy of *image* with a 'caption' key holding one
    caption string; *image* itself is not modified.

    Args:
        dataList: list receiving the new records (mutated in place).
        image: image record dict; its 'id' value selects the captions.
        captionLookup: mapping from image id to a list of caption strings.
            Defaults to the module-level ``imageId2annotation``. That default
            is a defaultdict, so an image without captions adds nothing.
    """
    if captionLookup is None:
        captionLookup = imageId2annotation
    for caption in captionLookup[image['id']]:
        # Deep copy so per-caption records never share mutable fields
        # (e.g. the object list) with each other or with *image*.
        record = deepcopy(image)
        record['caption'] = caption
        dataList.append(record)
# Partition the shuffled images. The first 5000 go to validation, the next
# 5000 go to test, and the rest go to training.
# - Val/test: one record per image, labelled in place with its first caption.
#   This assumes every such image has at least one caption (IndexError otherwise).
# - Train: one deep-copied record per caption, via add_caption.
for idx, image in enumerate(Images):
    if idx >= 10000:
        add_caption(trainData, image)
        continue
    image['caption'] = imageId2annotation[image['id']][0]
    target = valData if idx < 5000 else testData
    target.append(image)
            

            
print("trainData size {}, valData size {}, testData size {}".format(len(trainData), len(valData), len(testData)))
# Write each split to its own pickle under workdir. `with` guarantees each
# file is flushed and closed before the next one is written.
for split_file, split_data in (('train_split.p', trainData),
                               ('val_split.p', valData),
                               ('test_split.p', testData)):
    with open(os.path.join(workdir, split_file), 'wb') as fh:
        pickle.dump(split_data, fh)