-
Notifications
You must be signed in to change notification settings - Fork 0
/
magiclink.go
291 lines (261 loc) · 9.17 KB
/
magiclink.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
package gomagiclink
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base32"
"errors"
"fmt"
"slices"
"strconv"
"strings"
"time"
"github.com/oklog/ulid/v2"
)
// Capability interfaces for records that can report their own identifiers.
// *AuthUserRecord implements both (see its GetID and GetKeyName methods);
// presumably storage backends use them to derive keys — confirm in `storage`.
type (
	// RecordWithID is implemented by records identified by a ULID.
	RecordWithID interface {
		GetID() ulid.ULID
	}

	// RecordWithKeyName is implemented by records that can supply a string
	// key, e.g. for key-value databases.
	RecordWithKeyName interface {
		GetKeyName() string
	}
)
// UserAuthDatabase is the storage abstraction used by AuthMagicLinkController.
// Implement it to plug in a custom storage provider; see the implementations
// provided in the `storage` package for examples.
type UserAuthDatabase interface {
	// UserExistsByEmail reports whether a user with the given email is stored.
	UserExistsByEmail(email string) bool
	// StoreUser persists the given user record.
	StoreUser(user *AuthUserRecord) error
	// GetUserById loads the user with the given ID.
	GetUserById(id ulid.ULID) (*AuthUserRecord, error)
	// GetUserByEmail loads the user with the given email. VerifyChallenge
	// treats ErrUserNotFound from this method as "first login" and creates a
	// new record, so implementations should return it when no user matches.
	GetUserByEmail(email string) (*AuthUserRecord, error)
	// GetUserCount returns the number of stored users. May be slow.
	GetUserCount() (int, error)
	// UsersExist reports whether any users are stored. Expected to be fast.
	UsersExist() (bool, error)
}
// Token format constants shared by challenges and session IDs.
const (
	// challengeSignature is the leading marker of every challenge string.
	challengeSignature = "9"
	// sessionIdSignature is the leading marker of every session ID.
	sessionIdSignature = "S"
	// saltLength is the number of random bytes mixed into each token.
	saltLength = 8
)

// Sentinel errors returned by this package; compare with errors.Is.
var (
	// ErrUserAlreadyExists signals a duplicate user. Not returned by the code
	// in this file; presumably used by storage implementations.
	ErrUserAlreadyExists = errors.New("user already exists")
	// ErrUserNotFound is expected from UserAuthDatabase lookups when no user matches.
	ErrUserNotFound = errors.New("user not found")
	// ErrUserDisabled is returned when a verified token belongs to a disabled user.
	ErrUserDisabled = errors.New("user disabled")
	// ErrSecretKeyTooShort is returned by NewAuthMagicLinkController for keys under 16 bytes.
	ErrSecretKeyTooShort = errors.New("secret Key too short (min 16 bytes)")

	// Challenge verification failures: malformed, bad signature, past expiry.
	ErrInvalidChallenge = errors.New("invalid challenge")
	ErrBrokenChallenge  = errors.New("broken challenge")
	ErrExpiredChallenge = errors.New("expired challenge")

	// Session ID verification failures: malformed, bad signature, past expiry.
	ErrInvalidSessionId = errors.New("invalid session id")
	ErrBrokenSessionId  = errors.New("broken session id")
	ErrExpiredSessionId = errors.New("expired session id")
)
// AuthMagicLinkController provides everything needed to implement a magic-link
// login system: issuing and verifying login challenges, issuing and verifying
// session IDs, and access to the user store. Create instances with
// NewAuthMagicLinkController.
type AuthMagicLinkController struct {
	// secretKeyHash is the SHA-256 digest of the caller's secret key and is
	// the HMAC key for every challenge and session ID.
	secretKeyHash []byte
	// challengeExpDuration is how long a generated challenge stays valid.
	challengeExpDuration time.Duration
	// sessionExpDuration is how long a generated session ID stays valid; when
	// it is <= 0, GenerateSessionId writes an expiry time of 0.
	sessionExpDuration time.Duration
	// db is the user storage backend.
	db UserAuthDatabase
}
// NewAuthMagicLinkController configures and creates a new instance of the AuthMagicLinkController.
// The secretKey needs to be kept safe. To provide your own storage mechanism for the magic
// link data, implement the UserAuthDatabase interface. There are file system and SQL database
// implementations provided.
func NewAuthMagicLinkController(secretKey []byte, challengeExpDuration time.Duration, sessionExpDuration time.Duration, db UserAuthDatabase) (mlc *AuthMagicLinkController, err error) {
if len(secretKey) < 16 {
return nil, ErrSecretKeyTooShort
}
keyHash := sha256.Sum256(secretKey)
return &AuthMagicLinkController{
secretKeyHash: keyHash[:],
challengeExpDuration: challengeExpDuration,
sessionExpDuration: sessionExpDuration,
db: db,
}, nil
}
// makeHMAC returns the HMAC-SHA256 of payload, keyed with the controller's
// secret key hash. It signs and verifies both challenges and session IDs.
func (mlc *AuthMagicLinkController) makeHMAC(payload []byte) []byte {
	h := hmac.New(sha256.New, mlc.secretKeyHash)
	// hash.Hash.Write is documented never to return an error.
	_, _ = h.Write(payload)
	return h.Sum(nil)
}
// GetUserByEmail returns the stored user with the given email address, or the
// backend's error (ErrUserNotFound is the expected "no match" result).
//
// The address is passed through NormalizeEmail before the lookup. This is the
// same normalization GenerateChallenge and NewAuthUserRecord apply, so lookups
// use the same form of the address that this package stores.
func (mlc *AuthMagicLinkController) GetUserByEmail(email string) (*AuthUserRecord, error) {
	return mlc.db.GetUserByEmail(NormalizeEmail(email))
}
// StoreUser persists user through the configured UserAuthDatabase.
//
// VerifyChallenge and VerifySessionId only modify the returned record in
// memory (RecentLoginTime, and for first-time users the whole new record), so
// call StoreUser to save those changes.
func (mlc *AuthMagicLinkController) StoreUser(user *AuthUserRecord) error {
	store := mlc.db
	return store.StoreUser(user)
}
// UserExistsByEmail reports whether a user with the given email address is
// stored.
//
// The address is passed through NormalizeEmail first. That matches the
// normalization GenerateChallenge applies before issuing a challenge, so this
// answer agrees with the lookup VerifyChallenge later performs.
func (mlc *AuthMagicLinkController) UserExistsByEmail(email string) bool {
	return mlc.db.UserExistsByEmail(NormalizeEmail(email))
}
// GetUserCount returns the number of stored users as reported by the backend.
// The UserAuthDatabase interface marks this as potentially slow; prefer
// UsersExist when only emptiness matters.
func (mlc *AuthMagicLinkController) GetUserCount() (count int, err error) {
	count, err = mlc.db.GetUserCount()
	return count, err
}
// UsersExist reports, via the backend, whether any users are stored. The
// UserAuthDatabase interface marks this as the fast alternative to
// GetUserCount.
func (mlc *AuthMagicLinkController) UsersExist() (exist bool, err error) {
	exist, err = mlc.db.UsersExist()
	return exist, err
}
// GenerateChallenge returns a signed challenge string for email, to be embedded
// in a magic link and checked later with VerifyChallenge.
//
// The email is first normalized with NormalizeEmail. The challenge layout is
//
//	challengeSignature + b32(SALT) + "-" + b32(EMAIL) + "-" + EXPTIME + "-" + b32(MAC)
//
// where b32 is unpadded standard base32 and EXPTIME is the decimal Unix time
// challengeExpDuration from now. MAC is the HMAC over SALT, EMAIL and EXPTIME
// joined by zero bytes. The only possible error comes from reading the random
// salt.
func (mlc *AuthMagicLinkController) GenerateChallenge(email string) (challenge string, err error) {
	emailBytes := []byte(NormalizeEmail(email))

	salt := make([]byte, saltLength)
	if _, err = rand.Read(salt); err != nil {
		return "", err
	}

	expUnix := time.Now().Add(mlc.challengeExpDuration).Unix()
	payload := slices.Concat(salt, []byte{0}, emailBytes, []byte{0}, []byte(strconv.Itoa(int(expUnix))))
	sig := mlc.makeHMAC(payload)

	return fmt.Sprintf("%s%s-%s-%d-%s",
		challengeSignature,
		encodeToString(salt),
		encodeToString(emailBytes),
		expUnix,
		encodeToString(sig),
	), nil
}
// VerifyChallenge checks a challenge produced by GenerateChallenge and returns
// the AuthUserRecord for the email address it was issued to.
//
// If the database has no user with that address (GetUserByEmail reports
// ErrUserNotFound, or returns no record), a new enabled record is created with
// NewAuthUserRecord. That record is NOT stored here; call StoreUser to persist
// it. In every success case RecentLoginTime is set to the current time on the
// returned record, again only in memory.
//
// Errors:
//   - ErrInvalidChallenge: malformed input.
//   - ErrBrokenChallenge: the signature does not match.
//   - ErrExpiredChallenge: the challenge is past its expiry time.
//   - ErrUserDisabled: the user is disabled.
//   - Any other error returned by the database.
func (mlc *AuthMagicLinkController) VerifyChallenge(challenge string) (user *AuthUserRecord, err error) {
	body, ok := strings.CutPrefix(challenge, challengeSignature)
	if !ok {
		return nil, ErrInvalidChallenge
	}
	parts := strings.Split(body, "-")
	if len(parts) != 4 {
		return nil, ErrInvalidChallenge
	}
	salt, err := decodeFromString(parts[0])
	// A fixed salt length keeps the field boundaries in the signed payload
	// unambiguous: bytes cannot be shifted between salt and email.
	if err != nil || len(salt) != saltLength {
		return nil, ErrInvalidChallenge
	}
	email, err := decodeFromString(parts[1])
	if err != nil {
		return nil, ErrInvalidChallenge
	}
	expTime, err := strconv.Atoi(parts[2])
	if err != nil {
		return nil, ErrInvalidChallenge
	}
	hmac1, err := decodeFromString(parts[3])
	if err != nil {
		return nil, ErrInvalidChallenge
	}

	// Authenticate before trusting any field, so forged input is reported as
	// broken rather than merely expired.
	hmac2 := mlc.makeHMAC(slices.Concat(salt, []byte{0}, email, []byte{0}, []byte(strconv.Itoa(expTime))))
	if !hmac.Equal(hmac1, hmac2) {
		return nil, ErrBrokenChallenge
	}
	if expTime < int(time.Now().Unix()) {
		return nil, ErrExpiredChallenge
	}

	// The signed challenge proves control of the address: load the existing
	// user, or build a new record for a first-time login.
	user, err = mlc.db.GetUserByEmail(string(email))
	if errors.Is(err, ErrUserNotFound) || (err == nil && user == nil) {
		user, err = NewAuthUserRecord(string(email))
	}
	if err != nil {
		return nil, err
	}
	if !user.Enabled {
		return nil, ErrUserDisabled
	}
	user.RecentLoginTime = time.Now()
	return user, nil
}
// GenerateSessionId generates a session id suitable for using as a cookie
// in a web app.
func (mlc *AuthMagicLinkController) GenerateSessionId(user *AuthUserRecord) (sessionId string, err error) {
// Session ID is in the format:
// SALT-USER_ID-EXPTIME-HMAC(SALT || USER_ID || EXPTIME, secretKeyHash)
salt := make([]byte, saltLength)
_, err = rand.Read(salt)
if err != nil {
return
}
userId := user.ID.String()
expTime := 0
if mlc.sessionExpDuration > 0 {
expTime = int(time.Now().Add(mlc.sessionExpDuration).Unix())
}
expTimeStr := strconv.Itoa(expTime)
hmac := mlc.makeHMAC(slices.Concat(salt, []byte{0}, user.ID.Bytes(), []byte{0}, []byte(expTimeStr)))
return fmt.Sprintf("%s%s-%s-%s-%s", sessionIdSignature, encodeToString(salt), userId, expTimeStr, encodeToString(hmac)), nil
}
// VerifySessionId checks a session ID produced by GenerateSessionId. If it is
// authentic and unexpired, it returns the associated user's AuthUserRecord with
// RecentLoginTime set to the current time (in memory only; call StoreUser to
// persist it).
//
// An expiry field of 0 is what GenerateSessionId writes when
// sessionExpDuration is not positive; such session IDs never expire.
//
// Errors:
//   - ErrInvalidSessionId: malformed input.
//   - ErrBrokenSessionId: the signature does not match.
//   - ErrExpiredSessionId: the session ID is past its expiry time.
//   - ErrUserNotFound: the backend returns no record and no error.
//   - ErrUserDisabled: the user is disabled.
//   - Any error returned by the database's GetUserById.
func (mlc *AuthMagicLinkController) VerifySessionId(sessionId string) (user *AuthUserRecord, err error) {
	body, ok := strings.CutPrefix(sessionId, sessionIdSignature)
	if !ok {
		return nil, ErrInvalidSessionId
	}
	parts := strings.Split(body, "-")
	if len(parts) != 4 {
		return nil, ErrInvalidSessionId
	}
	salt, err := decodeFromString(parts[0])
	// Fixed-length salt keeps the signed payload's field boundaries unambiguous.
	if err != nil || len(salt) != saltLength {
		return nil, ErrInvalidSessionId
	}
	userId, err := ulid.ParseStrict(parts[1])
	if err != nil {
		return nil, ErrInvalidSessionId
	}
	expTime, err := strconv.Atoi(parts[2])
	if err != nil {
		return nil, ErrInvalidSessionId
	}
	hmac1, err := decodeFromString(parts[3])
	if err != nil {
		return nil, ErrInvalidSessionId
	}

	// Authenticate before trusting any field, so forged input is reported as
	// broken rather than merely expired.
	hmac2 := mlc.makeHMAC(slices.Concat(salt, []byte{0}, userId.Bytes(), []byte{0}, []byte(parts[2])))
	if !hmac.Equal(hmac1, hmac2) {
		return nil, ErrBrokenSessionId
	}
	// 0 is the "no expiry" marker written by GenerateSessionId.
	if expTime != 0 && expTime < int(time.Now().Unix()) {
		return nil, ErrExpiredSessionId
	}

	// The session ID is authentic, so userId was issued by us; load the user.
	user, err = mlc.db.GetUserById(userId)
	if err != nil {
		return nil, err
	}
	if user == nil {
		return nil, ErrUserNotFound
	}
	if !user.Enabled {
		return nil, ErrUserDisabled
	}
	user.RecentLoginTime = time.Now()
	return user, nil
}
// AuthUserRecord holds the stored data for one user of the magic-link system.
// *AuthUserRecord implements RecordWithID and RecordWithKeyName.
type AuthUserRecord struct {
	// ID uniquely identifies the user.
	ID ulid.ULID `json:"id"`
	// Enabled is checked by VerifyChallenge and VerifySessionId; disabled
	// users are rejected with ErrUserDisabled.
	Enabled bool `json:"enabled"`
	// Email is the user's address; it must also be unique.
	Email string `json:"email"`
	// AccessLevel is not used by the code in this file; presumably an
	// application-defined permission level.
	AccessLevel int `json:"access_level"`
	// FirstLoginTime is set when the record is created by NewAuthUserRecord.
	FirstLoginTime time.Time `json:"first_login_time"`
	// RecentLoginTime is refreshed on each successful verification.
	RecentLoginTime time.Time `json:"recent_login_time"`
	// CustomData lets applications attach any kind of data to the record.
	CustomData any `json:"custom_data"`
}
// NewAuthUserRecord returns a new, enabled AuthUserRecord for email. The record
// gets a fresh ULID and the NormalizeEmail form of the address, and both login
// timestamps are set to the current time. The error result is always nil.
// Users of this package normally don't call it directly; VerifyChallenge uses
// it for first-time logins.
func NewAuthUserRecord(email string) (aur *AuthUserRecord, err error) {
	created := time.Now()
	return &AuthUserRecord{
		ID:              ulid.Make(),
		Enabled:         true,
		Email:           NormalizeEmail(email),
		FirstLoginTime:  created,
		RecentLoginTime: created,
	}, nil
}
// GetID returns the user's ID. If IsZeroULID reports the ID as unset, a fresh
// ULID is generated and stored on the record first, so this getter may modify
// the receiver.
func (aur *AuthUserRecord) GetID() ulid.ULID {
	if !IsZeroULID(aur.ID) {
		return aur.ID
	}
	aur.ID = ulid.Make()
	return aur.ID
}
// GetKeyName returns a key suitable for key-value databases, formatted as
// "_<ULID>_<email>". As with GetID, a fresh ULID is assigned to the record
// first if its ID is unset.
func (aur *AuthUserRecord) GetKeyName() string {
	id := aur.GetID()
	return fmt.Sprintf("_%s_%s", id.String(), aur.Email)
}
// encodeToString encodes b as standard (RFC 4648) base32 with the trailing '='
// padding removed. The base32 alphabet contains no '-', so encoded values can
// safely be joined with '-' separators.
func encodeToString(b []byte) string {
	encoded := base32.StdEncoding.EncodeToString(b)
	// '=' only ever appears as a trailing run, so cutting at the first one
	// removes exactly the padding.
	if cut := strings.IndexByte(encoded, '='); cut >= 0 {
		encoded = encoded[:cut]
	}
	return encoded
}
// decodeFromString reverses encodeToString: it restores the '=' padding that
// was stripped and decodes the result as standard base32. Invalid input
// (including impossible unpadded lengths) yields a base32.CorruptInputError.
func decodeFromString(s string) ([]byte, error) {
	// Pad only a partial final block. An input whose length is already a
	// multiple of 8 (e.g. the encoding of 5, 10, 15... bytes, or "") needs no
	// padding. Appending a full "========" block would make the decoder reject it.
	if rem := len(s) % 8; rem != 0 {
		s += strings.Repeat("=", 8-rem)
	}
	return base32.StdEncoding.DecodeString(s)
}