This project demonstrates how to use PyTorch to build an image classifier for 102 flower species. It applies transfer learning with pretrained networks from the torchvision.models subpackage.
You can see full details on how the model was trained and how it makes predictions in this Jupyter notebook. Images come from this dataset (you can also download the dataset separated into categories here thanks to Udacity). This project is based on Udacity's AI Programming with Python Nanodegree.
Requirements:
- Python 3
- PyTorch
- TorchVision
- NumPy
You can train your own network by running the train.py script on the command line.
Basic usage:
python train.py data_directory
Prints out training loss, validation loss, and validation accuracy as the network trains.
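As a rough illustration of the reported metrics, the plain-Python sketch below computes validation loss and accuracy. It assumes the network ends in a log-softmax layer, so each output row holds per-class log-probabilities scored with negative log-likelihood loss. The function name `validation_metrics` and the list-based batch format are hypothetical; the actual script would work on PyTorch tensors instead.

```python
def validation_metrics(batches):
    """Average NLL loss and accuracy over validation batches (hypothetical helper).

    Each batch is (log_probs, labels): log_probs is a list of rows of
    per-class log-probabilities, labels is a list of true class indices.
    """
    total_loss = 0.0
    correct = 0
    count = 0
    for log_probs, labels in batches:
        for row, label in zip(log_probs, labels):
            # Negative log-likelihood of the true class for this image
            total_loss += -row[label]
            # The predicted class is the index with the highest log-probability
            predicted = max(range(len(row)), key=row.__getitem__)
            correct += int(predicted == label)
            count += 1
    # Average per image, not per batch, so a smaller final batch is weighted fairly
    return total_loss / count, correct / count
```

Averaging over images rather than over batches keeps the numbers correct when the last batch is smaller than the others.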
Options:
- Set directory to save checkpoints:
python train.py data_dir --save_dir save_directory
- Choose architecture:
python train.py data_dir --arch vgg
- Available network options: vgg, densenet, alexnet
- Set hyperparameters:
python train.py data_dir --learning_rate 0.01 --hidden_units 512 --epochs 20
- Use GPU for training:
python train.py data_dir --gpu
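In transfer learning the pretrained convolutional layers are typically frozen and only a new classifier head is trained, so --arch fixes the head's input size and --hidden_units its width. The sketch below is a minimal, stdlib-only way to see that, assuming a two-layer head (one hidden layer, 102 outputs) and assuming vgg, densenet and alexnet correspond to torchvision's vgg11, densenet121 and alexnet. The names `CLASSIFIER_IN_FEATURES`, `head_layers` and `trainable_parameters` are hypothetical, and train.py may build its head differently.

```python
# Hypothetical lookup: input size of each backbone's classifier.
# Assumes vgg -> vgg11, densenet -> densenet121, alexnet -> alexnet.
CLASSIFIER_IN_FEATURES = {
    "vgg": 25088,      # vgg11: 512 channels * 7 * 7 after adaptive pooling
    "densenet": 1024,  # densenet121: number of final feature channels
    "alexnet": 9216,   # alexnet: 256 channels * 6 * 6
}
NUM_CLASSES = 102  # flower species in the dataset


def head_layers(arch, hidden_units, num_classes=NUM_CLASSES):
    """Return (in_features, out_features) for each Linear layer of the new head."""
    if arch not in CLASSIFIER_IN_FEATURES:
        raise ValueError(f"unsupported architecture: {arch!r}")
    in_features = CLASSIFIER_IN_FEATURES[arch]
    return [(in_features, hidden_units), (hidden_units, num_classes)]


def trainable_parameters(arch, hidden_units, num_classes=NUM_CLASSES):
    """Count weights plus biases in the head; the frozen backbone adds none."""
    return sum(i * o + o for i, o in head_layers(arch, hidden_units, num_classes))
```

For example, with the defaults (vgg, 512 hidden units) only about 12.9 million parameters in the head are updated, while the VGG feature extractor stays fixed. In PyTorch each (in, out) pair would become an nn.Linear layer.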
Defaults:
- Save directory: current directory
- Network: vgg
- Learning rate: 0.01
- Hidden units: 512
- Epochs: 20
- GPU: Set to off by default
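The options and defaults above map directly onto Python's standard argparse module. The following is a sketch of one way to declare them; the flag names and defaults come from this README, while the function name `build_train_parser` and the help strings are hypothetical, and the real train.py may be organized differently.

```python
import argparse


def build_train_parser():
    """Parser mirroring the documented train.py options (hypothetical sketch)."""
    parser = argparse.ArgumentParser(description="Train a flower classifier.")
    parser.add_argument("data_dir", help="directory containing the image data")
    parser.add_argument("--save_dir", default=".",
                        help="directory for checkpoints (default: current directory)")
    parser.add_argument("--arch", default="vgg",
                        choices=["vgg", "densenet", "alexnet"],
                        help="pretrained network to build on")
    parser.add_argument("--learning_rate", type=float, default=0.01)
    parser.add_argument("--hidden_units", type=int, default=512)
    parser.add_argument("--epochs", type=int, default=20)
    # store_true keeps the GPU off unless the flag is given
    parser.add_argument("--gpu", action="store_true", help="train on the GPU")
    return parser
```

Calling `build_train_parser().parse_args()` reads sys.argv; with `choices`, an unsupported --arch value is rejected with a usage error before any training starts.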
Example: python train.py flowers --save_dir checkpoints --arch densenet --epochs 30 --gpu
Given an image and a checkpoint file, you can make predictions with a previously trained network by running the predict.py script on the command line.
For the flower identifier, I trained a network using VGG-11 and saved its checkpoint in the vgg11-checkpoint.pth file.
Basic usage:
python predict.py /path/to/image checkpoint
Predicts the flower species in the image and prints its probability. Both the image path and the checkpoint file are required.
Options:
- Return top K most likely classes:
python predict.py input checkpoint --top_k 3
- Use a mapping of categories to real names:
python predict.py input checkpoint --category_names cat_to_name.json
- Use GPU for inference:
python predict.py input checkpoint --gpu
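The --top_k and --category_names options can be sketched in plain Python as below. This assumes the model outputs log-probabilities and that the checkpoint keeps the class_to_idx mapping built by torchvision's ImageFolder, whose keys are the category folder names (such as 52 in flowers/test/52/image_04160.jpg). Because ImageFolder sorts folder names as strings, output index 1 is not necessarily folder "2", which is why the mapping is inverted. The function names are hypothetical.

```python
import heapq
import json
import math


def load_category_names(path):
    """Read a JSON object mapping category folder names to flower names."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def top_k_classes(log_probs, class_to_idx, k=1, cat_to_name=None):
    """Return the k most likely (label, probability) pairs for one image.

    log_probs: per-class log-probabilities from the model (assumed).
    class_to_idx: folder name -> output index, as built by ImageFolder.
    cat_to_name: optional folder name -> flower name mapping.
    """
    # Invert the mapping so an output index can be traced back to its folder
    idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
    # Undo the log so the reported values are probabilities
    probs = [math.exp(lp) for lp in log_probs]
    best = heapq.nlargest(k, range(len(probs)), key=probs.__getitem__)
    results = []
    for idx in best:
        label = idx_to_class[idx]
        if cat_to_name is not None:
            # Fall back to the folder name if the JSON has no entry for it
            label = cat_to_name.get(label, label)
        results.append((label, probs[idx]))
    return results
```

Without a category file the labels stay as folder names, which matches the default of no category names.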
Defaults:
- Category names: none
- Top K: 1
- GPU: Set to off by default
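The prediction options and defaults can likewise be declared with argparse. This is a hypothetical sketch: the positional arguments and flags follow this README, but the function name `build_predict_parser` and the help text are assumptions.

```python
import argparse


def build_predict_parser():
    """Parser mirroring the documented predict.py options (hypothetical sketch)."""
    parser = argparse.ArgumentParser(description="Predict the flower species in an image.")
    parser.add_argument("input", help="path to the image")
    parser.add_argument("checkpoint", help="path to the saved checkpoint")
    parser.add_argument("--top_k", type=int, default=1,
                        help="number of most likely classes to return")
    # None means no mapping: predictions are reported by category folder name
    parser.add_argument("--category_names", default=None,
                        help="JSON file mapping categories to real names")
    parser.add_argument("--gpu", action="store_true", help="run inference on the GPU")
    return parser
```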
Example: python predict.py flowers/test/52/image_04160.jpg vgg11-checkpoint.pth --top_k 5 --category_names cat_to_name.json
You could try it with the checkpoint provided in this repository and one of the images from the dataset (or your own)!