Skip to content

Commit

Permalink
runner: Add a helper to read and downcast a message
Browse files Browse the repository at this point in the history
We can't always use this because sometimes (particular in TLS 1.2), you
have to account for optional messages, but often we know exactly which
message we're expecting.

Change-Id: I4f6f59111fbf3e5f8a8fefa35802def9b2029196
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72949
Reviewed-by: Nick Harper <[email protected]>
Commit-Queue: David Benjamin <[email protected]>
  • Loading branch information
davidben authored and Boringssl LUCI CQ committed Nov 15, 2024
1 parent 6254482 commit f721d41
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 151 deletions.
13 changes: 13 additions & 0 deletions ssl/test/runner/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,19 @@ func (c *Conn) readHandshake() (any, error) {
return m, nil
}

func readHandshakeType[T any](c *Conn) (*T, error) {
m, err := c.readHandshake()
if err != nil {
return nil, err
}
mType, ok := m.(*T)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return nil, unexpectedMessageError(mType, m)
}
return mType, nil
}

// skipPacket processes all the DTLS records in packet. It updates
// sequence number expectations but otherwise ignores them.
func (c *Conn) skipPacket(packet []byte) error {
Expand Down
58 changes: 8 additions & 50 deletions ssl/test/runner/handshake_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1152,16 +1152,10 @@ func (hs *clientHandshakeState) doTLS13Handshake(msg any) error {
return err
}

msg, err := c.readHandshake()
encryptedExtensions, err := readHandshakeType[encryptedExtensionsMsg](c)
if err != nil {
return err
}

encryptedExtensions, ok := msg.(*encryptedExtensionsMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(encryptedExtensions, msg)
}
hs.writeServerHash(encryptedExtensions.marshal())

if !bytes.Equal(encryptedExtensions.extensions.echRetryConfigs, c.config.Bugs.ExpectECHRetryConfigs) {
Expand Down Expand Up @@ -1277,15 +1271,10 @@ func (hs *clientHandshakeState) doTLS13Handshake(msg any) error {
c.ocspResponse = certMsg.certificates[0].ocspResponse
c.sctList = certMsg.certificates[0].sctList

msg, err = c.readHandshake()
certVerifyMsg, err := readHandshakeType[certificateVerifyMsg](c)
if err != nil {
return err
}
certVerifyMsg, ok := msg.(*certificateVerifyMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certVerifyMsg, msg)
}

c.peerSignatureAlgorithm = certVerifyMsg.signatureAlgorithm
input := hs.finishedHash.certificateVerifyInput(serverCertificateVerifyContextTLS13)
Expand All @@ -1301,16 +1290,10 @@ func (hs *clientHandshakeState) doTLS13Handshake(msg any) error {
hs.writeServerHash(certVerifyMsg.marshal())
}

msg, err = c.readHandshake()
serverFinished, err := readHandshakeType[finishedMsg](c)
if err != nil {
return err
}
serverFinished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(serverFinished, msg)
}

verify := hs.finishedHash.serverSum(serverHandshakeTrafficSecret)
if len(verify) != len(serverFinished.verifyData) ||
subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
Expand Down Expand Up @@ -1341,14 +1324,10 @@ func (hs *clientHandshakeState) doTLS13Handshake(msg any) error {
// BoringSSL will always send two tickets half-RTT when
// negotiating 0-RTT.
for i := 0; i < shimConfig.HalfRTTTickets; i++ {
msg, err := c.readHandshake()
newSessionTicket, err := readHandshakeType[newSessionTicketMsg](c)
if err != nil {
return fmt.Errorf("tls: error reading half-RTT ticket: %s", err)
}
newSessionTicket, ok := msg.(*newSessionTicketMsg)
if !ok {
return errors.New("tls: expected half-RTT ticket")
}
// Defer processing until the resumption secret is computed.
deferredTickets = append(deferredTickets, newSessionTicket)
}
Expand Down Expand Up @@ -1622,16 +1601,10 @@ func (hs *clientHandshakeState) doFullHandshake() error {

var leaf *x509.Certificate
if hs.suite.flags&suitePSK == 0 {
msg, err := c.readHandshake()
certMsg, err := readHandshakeType[certificateMsg](c)
if err != nil {
return err
}

certMsg, ok := msg.(*certificateMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
hs.writeServerHash(certMsg.marshal())

if err := hs.verifyCertificates(certMsg); err != nil {
Expand All @@ -1641,15 +1614,10 @@ func (hs *clientHandshakeState) doFullHandshake() error {
}

if hs.serverHello.extensions.ocspStapling {
msg, err := c.readHandshake()
cs, err := readHandshakeType[certificateStatusMsg](c)
if err != nil {
return err
}
cs, ok := msg.(*certificateStatusMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(cs, msg)
}
hs.writeServerHash(cs.marshal())

if cs.statusType == statusTypeOCSP {
Expand Down Expand Up @@ -2176,15 +2144,10 @@ func (hs *clientHandshakeState) readFinished(out []byte) error {
return err
}

msg, err := c.readHandshake()
serverFinished, err := readHandshakeType[finishedMsg](c)
if err != nil {
return err
}
serverFinished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(serverFinished, msg)
}

if c.config.Bugs.EarlyChangeCipherSpec == 0 {
verify := hs.finishedHash.serverSum(hs.masterSecret)
Expand Down Expand Up @@ -2233,15 +2196,10 @@ func (hs *clientHandshakeState) readSessionTicket() error {
return errors.New("tls: received unexpected NewSessionTicket")
}

msg, err := c.readHandshake()
sessionTicketMsg, err := readHandshakeType[newSessionTicketMsg](c)
if err != nil {
return err
}
sessionTicketMsg, ok := msg.(*newSessionTicketMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(sessionTicketMsg, msg)
}

if c.config.Bugs.ExpectNoNonEmptyNewSessionTicket && len(sessionTicketMsg.ticket) != 0 {
return errors.New("tls: received unexpected non-empty NewSessionTicket")
Expand Down
Loading

0 comments on commit f721d41

Please sign in to comment.