
import json
import nltk
import json, os, string, random, time, pickle, gc

# Load the caption annotation JSON for each of the three splits.
with open('captions_train2014.json') as train_fp:
    train = json.load(train_fp)
with open('captions_val2014.json') as valid_fp:
    valid = json.load(valid_fp)
with open('captions_test2014.json') as test_fp:
    test = json.load(test_fp)

from pdb import set_trace as stop

# Map every training image id to its file name.
train_id_to_image = {img['id']: img['file_name'] for img in train['images']}

# Index the training captions both by file name and by image id.
# Entries are plain assignments, so when several annotations share an image
# the last one in the list wins and earlier captions are overwritten.
train_image_to_caption = {}
train_image_id_to_caption = {}
for ann in train['annotations']:
    ann_image_id = ann['image_id']
    ann_caption = ann['caption']
    train_image_to_caption[train_id_to_image[ann_image_id]] = ann_caption
    train_image_id_to_caption[ann_image_id] = ann_caption


# Map every validation image id to its file name.
valid_id_to_image = {img['id']: img['file_name'] for img in valid['images']}

# Index the validation captions both by file name and by image id.
# Entries are plain assignments, so when several annotations share an image
# the last one in the list wins and earlier captions are overwritten.
valid_image_to_caption = {}
valid_image_id_to_caption = {}
for ann in valid['annotations']:
    ann_image_id = ann['image_id']
    ann_caption = ann['caption']
    valid_image_to_caption[valid_id_to_image[ann_image_id]] = ann_caption
    valid_image_id_to_caption[ann_image_id] = ann_caption

# Map every test image id to its file name.
test_id_to_image = {img['id']: img['file_name'] for img in test['images']}

# Index the test captions both by file name and by image id.
# Entries are plain assignments, so when several annotations share an image
# the last one in the list wins and earlier captions are overwritten.
test_image_to_caption = {}
test_image_id_to_caption = {}
for ann in test['annotations']:
    ann_image_id = ann['image_id']
    ann_caption = ann['caption']
    test_image_to_caption[test_id_to_image[ann_image_id]] = ann_caption
    test_image_id_to_caption[ann_image_id] = ann_caption


# Combine validation and test captions into a single file-name -> caption map.
# NOTE: no copy is made; val_test_captions is the very same dict object as
# valid_image_to_caption, which therefore also receives the test entries.
valid_image_to_caption.update(test_image_to_caption)
val_test_captions = valid_image_to_caption


# with open('train_captions.json', 'w') as fp: json.dump(train_image_to_caption, fp)
# with open('val_captions.json', 'w') as fp: json.dump(val_test_captions, fp)


# train = pickle.load(open('../train.data','rb'))
# valid = pickle.load(open('../val_test.data','rb'))


# Build one record per row of the tab-separated training image listing;
# columns 1 and 2 of each row are stored as the image id and file name.
# NOTE(review): fields are kept as raw strings; if the file name is the last
# column it would still carry the trailing newline -- confirm the CSV layout.
train_samples = []
with open('../annotations/train2014_imgs.csv', 'r') as train_img_file:
    train_img_lines = train_img_file.readlines()
for row in train_img_lines:
    fields = row.split('\t')
    train_samples.append({'image_id': fields[1], 'file_name': fields[2]})

# Attach a binary vector to each record: every comma-separated count in the
# label file becomes 1 when positive and 0 otherwise. Rows are matched to
# records purely by position (row i -> train_samples[i]).
with open('../annotations/train2014_img_labels.csv', 'r') as train_label_file:
    train_label_lines = train_label_file.readlines()
for idx, row in enumerate(train_label_lines):
    train_samples[idx]['objects'] = [
        1 if int(count) > 0 else 0 for count in row.split(',')
    ]

# Build one record per row of the tab-separated validation image listing;
# columns 1 and 2 of each row are stored as the image id and file name.
# NOTE(review): fields are kept as raw strings; if the file name is the last
# column it would still carry the trailing newline -- confirm the CSV layout.
val_samples = []
with open('../annotations/val2014_imgs.csv', 'r') as val_img_file:
    val_img_lines = val_img_file.readlines()
for row in val_img_lines:
    fields = row.split('\t')
    val_samples.append({'image_id': fields[1], 'file_name': fields[2]})

# Attach a binary vector to each record: every comma-separated count in the
# label file becomes 1 when positive and 0 otherwise. Rows are matched to
# records purely by position (row i -> val_samples[i]).
with open('../annotations/val2014_img_labels.csv', 'r') as val_label_file:
    val_label_lines = val_label_file.readlines()
for idx, row in enumerate(val_label_lines):
    val_samples[idx]['objects'] = [
        1 if int(count) > 0 else 0 for count in row.split(',')
    ]


# Build one record per row of the tab-separated test image listing;
# columns 1 and 2 of each row are stored as the image id and file name.
# NOTE(review): fields are kept as raw strings; if the file name is the last
# column it would still carry the trailing newline -- confirm the CSV layout.
test_samples = []
with open('../annotations/test2014_imgs.csv', 'r') as test_img_file:
    test_img_lines = test_img_file.readlines()
for row in test_img_lines:
    fields = row.split('\t')
    test_samples.append({'image_id': fields[1], 'file_name': fields[2]})

# Attach a binary vector to each record: every comma-separated count in the
# label file becomes 1 when positive and 0 otherwise. Rows are matched to
# records purely by position (row i -> test_samples[i]).
with open('../annotations/test2014_img_labels.csv', 'r') as test_label_file:
    test_label_lines = test_label_file.readlines()
for idx, row in enumerate(test_label_lines):
    test_samples[idx]['objects'] = [
        1 if int(count) > 0 else 0 for count in row.split(',')
    ]

# Validation and test records are handled as one combined split from here on.
val_test_samples = val_samples + test_samples

# Add the caption looked up by file name to every record. The records are
# mutated in place; a missing file name raises KeyError.
for record in train_samples:
    record['caption'] = train_image_to_caption[record['file_name']]

for record in val_test_samples:
    record['caption'] = val_test_captions[record['file_name']]

# Persist the assembled sample records. The files are opened through context
# managers so the handles are flushed and closed deterministically, even if
# pickling raises part-way through.
with open('../train.data', "wb") as train_out:
    pickle.dump(train_samples, train_out)
with open('../val_test.data', "wb") as val_test_out:
    pickle.dump(val_test_samples, val_test_out)



# Load the instance annotation files. Note that `train` is rebound here: from
# this point on it holds the instances JSON rather than the captions JSON
# loaded at the top of the script.
# Context managers guarantee the file handles are closed after parsing.
with open('../annotations/instances_train2014.json', 'r') as train_in:
    train = json.load(train_in)
with open('../annotations/instances_val2014.json', 'r') as val_in:
    val = json.load(val_in)




# Copy the caption for each annotation's image onto the annotation itself
# (in place). An image id without a caption raises KeyError.
for ann in train['annotations']:
    ann['caption'] = train_image_id_to_caption[ann['image_id']]

# Copy the caption for each annotation's image onto the annotation itself
# (in place). The caption is looked up in the validation map first and, when
# the image id is not there, in the test map.
for ann in val['annotations']:
    image_id = ann['image_id']
    try:
        caption = valid_image_id_to_caption[image_id]
    except KeyError:
        # Only a missing key triggers the fallback; any other error (or a
        # Ctrl-C) propagates instead of being silently swallowed. An id
        # missing from both maps raises KeyError from this lookup.
        caption = test_image_id_to_caption[image_id]
    ann['caption'] = caption
    

# Write the caption-augmented annotations back over the same instance files
# they were loaded from (the originals are overwritten).
with open('../annotations/instances_train2014.json', 'w') as train_out:
    json.dump(train, train_out)
with open('../annotations/instances_val2014.json', 'w') as val_out:
    json.dump(val, val_out)

# NOTE(review): this drops into pdb (`stop` is pdb.set_trace); the code below
# only runs after continuing from the debugger -- confirm it is intentional.
stop()






# Translation table that deletes every ASCII punctuation character; built once
# so it is not recreated for each caption.
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)


def filter(txt):
    """Normalise *txt* into a list of lower-case, punctuation-free words.

    The input is converted with ``str()``, lower-cased, stripped of all
    characters in ``string.punctuation`` and split on whitespace.

    Args:
        txt: The caption (any object; it is passed through ``str()``).

    Returns:
        list[str]: The words of the caption, possibly empty.
    """
    # str.translate expects a mapping from code points to replacements.
    # Passing string.punctuation itself would be indexed by code point, which
    # removes nothing and rewrites control characters such as "\n" instead.
    return str(txt).lower().translate(_PUNCTUATION_TABLE).strip().split()

def one_caption(args):
    """Add the words of one caption to a shared word-count mapping.

    Args:
        args: Indexable triple ``(idx, vocabulary, caption)`` where ``idx``
            is the caption's position (used only for progress output),
            ``vocabulary`` is a dict-like word -> count mapping updated in
            place, and ``caption`` is the text to tokenise with ``filter``.

    Returns:
        None. Any exception raised while processing is caught and reported
        on stdout rather than propagated.
    """
    idx, vocabulary, caption = args[0], args[1], args[2]
    try:
        for token in filter(caption):
            vocabulary[token] = vocabulary.get(token, 0) + 1
        # Progress marker every 100000 captions.
        if idx % 100000 == 0:
            print(idx)
    except Exception as ex:
        report = "An exception of type {0} occured. Arguments:\n{1!r}".format(
            type(ex).__name__, ex.args
        )
        print(report)



from multiprocessing import Pool, Manager

# Count word frequencies over every training caption.
# Counting happens in this process with a plain dict: one_caption increments a
# word with a separate `get` and store, which is not atomic when the mapping
# is a Manager dict shared between pool workers, so concurrent workers can
# lose counts. It also avoids spawning workers from an unguarded script body.
vocabulary = {}
for idx, ann in enumerate(train['annotations']):
    one_caption((idx, vocabulary, ann['caption']))
print(len(vocabulary))

# Turn the mapping into (count, word) pairs ordered from most to least
# frequent (ties broken by reverse word order).
vocabulary = sorted(((count, w) for w, count in vocabulary.items()), reverse=True)
print("top words and their counts: ")
print("\n".join(map(str, vocabulary[:20])))

# Save the sorted (count, word) list; the context manager closes the file.
with open(os.path.join('./', 'vocab_caption.p'), 'wb') as vocab_file:
    pickle.dump(vocabulary, vocab_file)