图像的标签存放在 JSON 文件 instances.json 中，文件每行是一条独立的 JSON 记录。
%matplotlib inlineimport jsonimport gluonbook as gbimport mxnet as mxfrom mxnet import autograd, gluon, image, init, ndfrom mxnet.gluon import data as gdata, loss as gloss, utils as gutilsimport sysfrom time import timetrain_Pedestrian_url = []train_Cyclist_url = []train_Others_url = []with open('instances.json',encoding='utf-8') as f: for _ in range(100000): if len(train_Pedestrian_url) + len(train_Cyclist_url) + len(train_Others_url) >= 300: break line = f.readline() js = json.loads(line) if js['attrs']['ignore']=='yes' or js['attrs']['occlusion']=='heavily_occluded' or js['attrs']['occlusion']=='invisible': continue if js['attrs']['type'] == 'Pedestrian': if len(train_Pedestrian_url) >=100: continue train_Pedestrian_url.append(js['thumbnail_path']) elif js['attrs']['type'] == 'Cyclist': if len(train_Cyclist_url) >=100: continue train_Cyclist_url.append(js['thumbnail_path']) elif js['attrs']['type'] == 'Others': if len(train_Others_url) >=100: continue train_Others_url.append(js['thumbnail_path']) # img = image.imread(url) f.close()print(train_Cyclist_url)print(len(train_Pedestrian_url),len(train_Cyclist_url),len(train_Others_url))img = image.imread('/mnt/hdfs-data-4/data/'+train_Cyclist_url[0])img.astype('float32')labels = nd.zeros(shape=(30000,))labels[10000:20000] = 1labels[20000:] = 2
数据整理基本完成，接下来就是搭建网络、训练模型了。