-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
24 changed files
with
620 additions
and
967 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
name: publish package | ||
|
||
on: | ||
release: | ||
types: [ created ] | ||
|
||
jobs: | ||
build: | ||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: actions/setup-python@v5 | ||
with: | ||
python-version: '3.8' | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip setuptools wheel | ||
python -m pip install twine | ||
- name: Build and publish | ||
env: | ||
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} | ||
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} | ||
run: | | ||
python setup.py sdist bdist_wheel | ||
twine upload dist/* |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,27 @@ | ||
name: Unit Tests | ||
name: unit tests | ||
|
||
on: [push] | ||
on: | ||
workflow_dispatch: | ||
push: | ||
schedule: | ||
- cron: "0 21 * * 6" | ||
|
||
jobs: | ||
build: | ||
|
||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- uses: actions/checkout@v2 | ||
- name: Set up Python | ||
uses: actions/setup-python@v2 | ||
- uses: actions/checkout@v4 | ||
- uses: actions/setup-python@v5 | ||
with: | ||
python-version: '3.8' | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
python -m pip install torch | ||
python -m pip install -e '.[dev]' | ||
python -m pip install pip --upgrade | ||
python -m pip install -r requirements.txt | ||
python -m pip install pytest hypothesis torchnyan | ||
python -m pip install git+https://github.com/speedcell4/torchrua.git@develop --force-reinstall --no-deps | ||
python -m pip install pytorch-crf torch-struct | ||
- name: Test with pytest | ||
run: | | ||
python -m pytest tests |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,111 +1,24 @@ | ||
# TorchLatent | ||
<div align="center"> | ||
|
||
![Unit Tests](https://github.com/speedcell4/torchlatent/workflows/Unit%20Tests/badge.svg) | ||
[![PyPI version](https://badge.fury.io/py/torchlatent.svg)](https://badge.fury.io/py/torchlatent) | ||
[![Downloads](https://pepy.tech/badge/torchrua)](https://pepy.tech/project/torchrua) | ||
# TorchLatent | ||
|
||
## Requirements | ||
![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/speedcell4/torchlatent/unit-tests.yml?cacheSeconds=0) | ||
![PyPI - Version](https://img.shields.io/pypi/v/torchlatent?label=pypi%20version&cacheSeconds=0) | ||
![PyPI - Downloads](https://img.shields.io/pypi/dm/torchlatent?cacheSeconds=0) | ||
|
||
- Python 3.8 | ||
- PyTorch 1.10.2 | ||
</div> | ||
|
||
## Installation | ||
|
||
`python3 -m pip torchlatent` | ||
|
||
## Performance | ||
|
||
``` | ||
TorchLatent (0.109244) => 0.003781 0.017763 0.087700 0.063497 | ||
Third (0.232487) => 0.103277 0.129209 0.145311 | ||
``` | ||
|
||
## Usage | ||
|
||
```python | ||
import torch | ||
from torchrua import pack_sequence | ||
|
||
from torchlatent.crf import CrfDecoder | ||
|
||
num_tags = 3 | ||
num_conjugates = 1 | ||
|
||
decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) | ||
|
||
emissions = pack_sequence([ | ||
torch.randn((5, num_conjugates, num_tags), requires_grad=True), | ||
torch.randn((2, num_conjugates, num_tags), requires_grad=True), | ||
torch.randn((3, num_conjugates, num_tags), requires_grad=True), | ||
]) | ||
|
||
tags = pack_sequence([ | ||
torch.randint(0, num_tags, (5, num_conjugates)), | ||
torch.randint(0, num_tags, (2, num_conjugates)), | ||
torch.randint(0, num_tags, (3, num_conjugates)), | ||
]) | ||
|
||
print(decoder.fit(emissions=emissions, tags=tags)) | ||
# tensor([[-6.7424], | ||
# [-5.1288], | ||
# [-2.7283]], grad_fn=<SubBackward0>) | ||
|
||
print(decoder.decode(emissions=emissions)) | ||
# PackedSequence(data=tensor([[2], | ||
# [0], | ||
# [1], | ||
# [0], | ||
# [2], | ||
# [0], | ||
# [2], | ||
# [0], | ||
# [1], | ||
# [2]]), | ||
# batch_sizes=tensor([3, 3, 2, 1, 1]), | ||
# sorted_indices=tensor([0, 2, 1]), | ||
# unsorted_indices=tensor([0, 2, 1])) | ||
|
||
print(decoder.marginals(emissions=emissions)) | ||
# tensor([[[0.1040, 0.1001, 0.7958]], | ||
# | ||
# [[0.5736, 0.0784, 0.3479]], | ||
# | ||
# [[0.0932, 0.8797, 0.0271]], | ||
# | ||
# [[0.6558, 0.0472, 0.2971]], | ||
# | ||
# [[0.2740, 0.1109, 0.6152]], | ||
# | ||
# [[0.4811, 0.2163, 0.3026]], | ||
# | ||
# [[0.2321, 0.3478, 0.4201]], | ||
# | ||
# [[0.4987, 0.1986, 0.3027]], | ||
# | ||
# [[0.2029, 0.5888, 0.2083]], | ||
# | ||
# [[0.2802, 0.2358, 0.4840]]], grad_fn=<AddBackward0>) | ||
``` | ||
`python -m pip torchlatent` | ||
|
||
## Latent Structures | ||
|
||
- [ ] Conditional Random Fields (CRF) | ||
- [x] Conjugated | ||
- [ ] Dynamic Transition Matrix | ||
- [ ] Second-order | ||
- [ ] Variant-order | ||
- [ ] Tree CRF | ||
- [x] Conditional Random Fields (CRF) | ||
- [x] Cocke–Kasami-Younger Algorithm (CKY) | ||
- [ ] Probabilistic Context-free Grammars (CFG) | ||
- [ ] Connectionist Temporal Classification (CTC) | ||
- [ ] Recurrent Neural Network Grammars (RNNG) | ||
- [ ] Non-Projective Dependency Tree (Matrix-tree Theorem) | ||
- [ ] Probabilistic Context-free Grammars (PCFG) | ||
- [ ] Dependency Model with Valence (DMV) | ||
|
||
## Citation | ||
|
||
``` | ||
@misc{wang2020torchlatent, | ||
title={TorchLatent: High Performance Structured Prediction in PyTorch}, | ||
author={Yiran Wang}, | ||
year={2020}, | ||
howpublished = "\url{https://github.com/speedcell4/torchlatent}" | ||
} | ||
``` | ||
- [ ] Autoregressive Decoding (Beam Search) |
Empty file.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
torch | ||
torchrua |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,22 @@ | ||
from setuptools import setup, find_packages | ||
from pathlib import Path | ||
|
||
from setuptools import find_packages, setup | ||
|
||
name = 'torchlatent' | ||
|
||
root_dir = Path(__file__).parent.resolve() | ||
with (root_dir / 'requirements.txt').open(mode='r', encoding='utf-8') as fp: | ||
install_requires = [install_require.strip() for install_require in fp] | ||
|
||
setup( | ||
name=name, | ||
version='0.4.2', | ||
version='0.4.3', | ||
packages=[package for package in find_packages() if package.startswith(name)], | ||
url='https://github.com/speedcell4/torchlatent', | ||
license='MIT', | ||
author='speedcell4', | ||
author_email='[email protected]', | ||
description='High Performance Structured Prediction in PyTorch', | ||
python_requires='>=3.8', | ||
install_requires=[ | ||
'numpy', | ||
'torchrua>=0.4.0', | ||
], | ||
extras_require={ | ||
'dev': [ | ||
'einops', | ||
'pytest', | ||
'hypothesis', | ||
'pytorch-crf', | ||
], | ||
} | ||
install_requires=install_requires, | ||
) |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.