import ast

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch import __version__
from torch.autograd import Variable

# Pretrained classifiers, loaded once at import time.
# NOTE(review): ``pretrained=True`` is deprecated in newer torchvision releases
# in favour of ``weights=...``; kept here for compatibility with older installs.
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
vgg16 = models.vgg16(pretrained=True)

# Maps the ``model_name`` accepted by ``classifier`` to a loaded model.
# NOTE: this rebinds ``models`` and so shadows the ``torchvision.models``
# module from here on; the name is kept because other code may import it.
models = {'resnet': resnet18, 'alexnet': alexnet, 'vgg': vgg16}

# obtain ImageNet labels
# (presumably a dict literal mapping class index -> human-readable label;
# ``classifier`` indexes it with the predicted class index)
with open('imagenet1000_clsid_to_human.txt') as imagenet_classes_file:
    imagenet_classes_dict = ast.literal_eval(imagenet_classes_file.read())

# Image preprocessing: resize, center-crop to 224x224, convert to a tensor,
# and normalize with the given per-channel mean/std. Built once here instead
# of on every ``classifier`` call, since it never changes.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# PyTorch 0.4 merged Variable into Tensor, deprecated ``volatile=True`` and
# introduced ``torch.no_grad()`` as its replacement. Detect the feature
# directly rather than parsing ``__version__`` strings.
_HAS_NO_GRAD = hasattr(torch, 'no_grad')


def classifier(img_path, model_name):
    """Classify an image with one of the pretrained models in ``models``.

    Args:
        img_path: Path of the image to classify (anything ``Image.open``
            accepts).
        model_name: Key into ``models``: 'resnet', 'alexnet' or 'vgg'.

    Returns:
        The entry of ``imagenet_classes_dict`` for the highest-scoring
        class index predicted by the model.

    Raises:
        KeyError: if ``model_name`` is not a key of ``models``.
    """
    # Load the image and convert it to 3-channel RGB so grayscale, palette,
    # RGBA or CMYK inputs match the 3-channel Normalize step. The ``with``
    # block closes the underlying file handle once the pixels are read.
    with Image.open(img_path) as img_pil:
        img_tensor = _PREPROCESS(img_pil.convert('RGB'))

    # Add a leading batch dimension (batch of one).
    img_tensor = img_tensor.unsqueeze(0)

    # Put the model in evaluation mode instead of the default training mode.
    model = models[model_name]
    model = model.eval()

    if _HAS_NO_GRAD:
        # PyTorch >= 0.4: run the forward pass with autograd disabled.
        # Setting requires_grad=False on the input alone is NOT enough: the
        # model's parameters still require grad, so a graph would be built.
        with torch.no_grad():
            output = model(img_tensor)
    else:
        # PyTorch < 0.4: a volatile Variable is the inference-mode mechanism.
        data = Variable(img_tensor, volatile=True)
        output = model(data)

    # Index of the highest-scoring class (``.data`` works on both old
    # Variables and modern tensors).
    pred_idx = int(output.data.numpy().argmax())
    return imagenet_classes_dict[pred_idx]