Skip to content

Commit

Permalink
Merge pull request #55 from Prateeknandle/rpc-fix
Browse files Browse the repository at this point in the history
fix(relay/client): error handling for RPCs and replacing errgroup with sync.WaitGroup
  • Loading branch information
DelusionalOptimist authored May 22, 2024
2 parents 2496e63 + e68b97b commit 65910d6
Showing 1 changed file with 82 additions and 53 deletions.
135 changes: 82 additions & 53 deletions relay-server/server/relayServer.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,6 @@ func NewClient(server string) *LogClient {
kg.Warnf("Failed to call WatchLogs (%s)\n err=%s", server, err.Error())
return nil
}
// == //

// set wait group
lc.WgServer, lc.Context = errgroup.WithContext(context.Background())

return lc
}
Expand All @@ -402,30 +398,35 @@ func (lc *LogClient) DoHealthCheck() bool {
}

// WatchMessages Function
func (lc *LogClient) WatchMessages(ctx context.Context) error {
func (lc *LogClient) WatchMessages(wg *sync.WaitGroup, stop chan struct{}, errCh chan error) {

defer wg.Done()

var err error

for lc.Running {
var res *pb.Message

if res, err = lc.MsgStream.Recv(); err != nil {
return fmt.Errorf("failed to receive a message (%s) %s", lc.Server, err.Error())

}
select {
case MsgBufferChannel <- res:
case <-ctx.Done():
// The context is over, stop processing results
return nil
case <-stop:
return
default:
//not able to add it to Log buffer
if res, err = lc.MsgStream.Recv(); err != nil {
errCh <- fmt.Errorf("failed to receive a message (%s) %s", lc.Server, err.Error())
return
}

select {
case MsgBufferChannel <- res:
case <-stop:
return
default:
// Not able to add it to Message buffer
}
}
}

kg.Print("Stopped watching messages from " + lc.Server)

return nil
}

// AddMsgFromBuffChan Adds Msg from MsgBufferChannel into MsgStructs
Expand Down Expand Up @@ -461,30 +462,35 @@ func (rs *RelayServer) AddMsgFromBuffChan() {
}

// WatchAlerts Function
func (lc *LogClient) WatchAlerts(ctx context.Context) error {
func (lc *LogClient) WatchAlerts(wg *sync.WaitGroup, stop chan struct{}, errCh chan error) {

defer wg.Done()

var err error

for lc.Running {
var res *pb.Alert

if res, err = lc.AlertStream.Recv(); err != nil {
return fmt.Errorf("failed to receive a alert (%s) %s", lc.Server, err.Error())
}

select {
case AlertBufferChannel <- res:
case <-ctx.Done():
// The context is over, stop processing results
return nil
case <-stop:
return
default:
//not able to add it to Log buffer
if res, err = lc.AlertStream.Recv(); err != nil {
errCh <- fmt.Errorf("failed to receive an alert (%s) %s", lc.Server, err.Error())
return
}

select {
case AlertBufferChannel <- res:
case <-stop:
return
default:
// Not able to add it to Alert buffer
}
}
}

kg.Print("Stopped watching alerts from " + lc.Server)

return nil
}

// AddAlertFromBuffChan Adds Alert from AlertBufferChannel into AlertStructs
Expand Down Expand Up @@ -520,30 +526,34 @@ func (rs *RelayServer) AddAlertFromBuffChan() {
}

// WatchLogs Function
func (lc *LogClient) WatchLogs(ctx context.Context) error {
func (lc *LogClient) WatchLogs(wg *sync.WaitGroup, stop chan struct{}, errCh chan error) {
defer wg.Done()

var err error

for lc.Running {
var res *pb.Log

if res, err = lc.LogStream.Recv(); err != nil {
return fmt.Errorf("failed to receive a log (%s) %s", lc.Server, err.Error())
}

select {
case LogBufferChannel <- res:
case <-ctx.Done():
// The context is over, stop processing results
return nil
case <-stop:
return
default:
//not able to add it to Log buffer
if res, err = lc.LogStream.Recv(); err != nil {
errCh <- fmt.Errorf("failed to receive a log (%s) %s", lc.Server, err.Error())
return
}

select {
case LogBufferChannel <- res:
case <-stop:
return
default:
// Not able to add it to Log buffer
}
}
}

kg.Print("Stopped watching logs from " + lc.Server)

return nil
}

// AddLogFromBuffChan Adds Log from LogBufferChannel into LogStructs
Expand Down Expand Up @@ -744,26 +754,45 @@ func connectToKubeArmor(nodeIP, port string) error {
}
kg.Printf("Checked the liveness of KubeArmor's gRPC service (%s)", server)

// watch messages
client.WgServer.Go(func() error {
return client.WatchMessages(client.Context)
})
var wg sync.WaitGroup
stop := make(chan struct{})
errCh := make(chan error, 1)

// Start watching messages
wg.Add(1)
go func() {
client.WatchMessages(&wg, stop, errCh)
}()
kg.Print("Started to watch messages from " + server)

// watch alerts
client.WgServer.Go(func() error {
return client.WatchAlerts(client.Context)
})
// Start watching alerts
wg.Add(1)
go func() {
client.WatchAlerts(&wg, stop, errCh)
}()
kg.Print("Started to watch alerts from " + server)

// watch logs
client.WgServer.Go(func() error {
return client.WatchLogs(client.Context)
})
// Start watching logs
wg.Add(1)
go func() {
client.WatchLogs(&wg, stop, errCh)
}()
kg.Print("Started to watch logs from " + server)

if err := client.WgServer.Wait(); err != nil {
// Wait for an error or all goroutines to finish
select {
case err := <-errCh:
close(stop) // Stop other goroutines
kg.Warn(err.Error())
case <-func() chan struct{} {
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
return done
}():
// All goroutines finished without error
}

if err := client.DestroyClient(); err != nil {
Expand Down

0 comments on commit 65910d6

Please sign in to comment.