Files
cloud-services/services/gateway/websocket/session.go

414 lines
10 KiB
Go

package websocket
import (
	"compress/flate"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"sort"
	"strings"
	"time"

	"fiskerinc.com/modules/common"
	"fiskerinc.com/modules/grpc/kafka_grpc"
	"fiskerinc.com/modules/kafka"
	"fiskerinc.com/modules/logger"
	"fiskerinc.com/modules/utils"
	"fiskerinc.com/modules/utils/envtool"
	"fiskerinc.com/modules/validator"
	"gateway/sloppy"
	"github.com/gobwas/ws"
	"github.com/gobwas/ws/wsflate"
	"github.com/gobwas/ws/wsutil"
	"github.com/pkg/errors"
	"google.golang.org/protobuf/proto"
	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)
// deadline is how far into the future extendDeadline pushes the websocket's
// deadline. It is evaluated once at package initialisation from the
// WS_TIMEOUT environment variable, interpreted in seconds (30 is passed to
// envtool.GetEnvInt as the fallback value).
var deadline = time.Second * time.Duration(envtool.GetEnvInt("WS_TIMEOUT", 30))
// SessionInterface provides methods for connection
type SessionInterface interface {
Authenticate() error
Key() string
SendMsgToClient(message []byte) error
Receive() ([]byte, ws.OpCode, error)
Listen(context.Context, kafka.ProducerInterface) error
Load(kafka.ProducerInterface) error
Teardown(kafka.ProducerInterface) error
Close() error
GetWebsocket() net.Conn
GetIP() string
GetType() string
IsDevice(device common.Device) bool
GetID() string
GetUUID() int64
GetVIN() (string)
SkipTeardown(skip bool)
}
// NewSecureSession creates a session w/ websocket based off the user-agent
// given in the HTTP request. Before any session is created, the request must
// carry a VIN that passes validation and is allowed by the VIN blocker. The
// VIN is upper-cased before the blocker check and before it is handed to the
// session constructor.
//
// ex: "Fisker Ocean T.Rex 1.2.3.4 abc123" - T.Rex
// ex: "HMI 2.0.0.0" - HMI
//
// Devices other than "fisker" and "hmi" yield ErrFailedToLoad.
func NewSecureSession(w http.ResponseWriter, r *http.Request) (SessionInterface, error) {
	var s SessionInterface

	vin, err := utils.ParseVINFromRequest(r)
	if err != nil {
		// Attach the cause so the rejection is diagnosable from the logs.
		logger.At(logger.Error(), "no vin from request", "conn").Err(err).Send()
		return s, err
	}
	if !validator.ValidateVINSimple(vin) {
		logger.Error().Str("type", "conn").Str("VIN", vin).Msg("NewSecureSession failed to validate vin")
		return s, errors.Errorf("%s failed to validate VIN", vin)
	}
	vin = strings.ToUpper(vin)
	if !sloppy.GetVINBlocker().IsVINAllowed(vin) {
		return s, errors.Errorf("%s is not an allowed VIN, please contact support", vin)
	}

	device, version := ParseDeviceAndVersionFromRequest(r)
	switch device {
	case "fisker":
		// An ICCID parse failure is not fatal: warn and continue with
		// whatever value ParseICCIDFromRequest returned.
		iccid, err := ParseICCIDFromRequest(r)
		if err != nil {
			logger.At(logger.Warn(), "1:"+vin, "conn").
				Err(errors.WithMessagef(err, "failed to parse ICCID from request %s", vin)).Send()
		}
		s, err = NewTRexSession(w, r, vin, version, iccid)
		if err != nil {
			logger.At(logger.Warn(), "1:"+vin, "conn").
				Err(errors.WithMessagef(err, "failed to create Trex session %s", vin)).Send()
			return s, err
		}
		logger.At(logger.Info(), "1:"+vin, "conn").Send()
	case "hmi":
		s, err = NewHMISession(w, r, vin, version)
		if err != nil {
			logger.At(logger.Warn(), "2:"+vin, "conn").
				Err(errors.WithMessagef(err, "failed to create HMI session %s", vin)).Send()
			return s, err
		}
		logger.At(logger.Info(), "2:"+vin, "conn").Send()
	default:
		return s, ErrFailedToLoad
	}
	return s, nil
}
// NewInsecureSession creates a session w/ websocket for clients that do not
// present a VIN, choosing the implementation from the device parsed out of
// the request's user-agent.
//
// ex: "Mobile 1.2.3.4" - Mobile
//
// Accepted devices are "mobile", "android" and "ios"; anything else yields
// ErrFailedToLoad.
func NewInsecureSession(w http.ResponseWriter, r *http.Request) (SessionInterface, error) {
	device, version := ParseDeviceAndVersionFromRequest(r)
	switch device {
	case "mobile", "android", "ios":
		// supported: fall through to session creation below
	default:
		return nil, ErrFailedToLoad
	}

	mobile, err := NewMobileSession(w, r, version)
	if err != nil {
		return mobile, err
	}
	logger.At(logger.Info(), "3: "+mobile.GetID(), "conn").Send()
	return mobile, nil
}
// NewSession upgrades the HTTP request to a websocket for a client whose
// device type is not known. The returned *Session has Type common.Unknown,
// and its epoch (see GetUUID) is the creation time in Unix nanoseconds.
func NewSession(w http.ResponseWriter, r *http.Request) (SessionInterface, error) {
	conn, _, _, err := ws.UpgradeHTTP(r, w)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	session := &Session{
		Websocket: conn,
		Type:      common.Unknown,
		epoch:     time.Now().UnixNano(),
	}
	return session, nil
}
// Session contains websocket info and implements SessionInterface. It is the
// concrete type returned by NewSession for clients of unknown device type.
type Session struct {
	// Websocket is the upgraded client connection that all reads and writes use.
	Websocket net.Conn
	// Type is the client's device kind. Key uses ID unmodified when it is
	// common.Unknown, and Type.Key(ID) otherwise.
	Type common.Device
	ID   string // used for key generation to kafka
	// Version is presumably the client version parsed from the request; it is
	// not set by NewSession. TODO confirm against the other constructors.
	Version string
	// epoch identifies the session and is exposed via GetUUID. NewSession sets
	// it to the creation time in Unix nanoseconds.
	epoch int64
	// skipteardown is set via SkipTeardown. When true, Teardown returns
	// without notifying the depot service.
	skipteardown bool
}
// Authenticate reads the first message from the client, decodes it as a JSON
// AuthEvent and checks it with AuthenticateRequest. On success the event's
// Key becomes the session ID and nil is returned. Any read, decode or
// authentication failure is returned as an error, and the ID is left as it was.
func (s *Session) Authenticate() error {
	raw, _, err := s.Receive()
	if err != nil {
		return err
	}

	var event AuthEvent
	if err := json.Unmarshal(raw, &event); err != nil {
		return errors.WithStack(err)
	}

	ok, err := AuthenticateRequest(event)
	switch {
	case err != nil:
		return err
	case !ok:
		return errors.New("failed authentication")
	}

	s.ID = event.Key
	return nil
}
// Key returns the Kafka message key for this session. Sessions of unknown
// device type use their ID as-is; all others delegate to the device type's
// Key formatting.
func (s *Session) Key() string {
	switch s.Type {
	case common.Unknown:
		return s.ID
	default:
		return s.Type.Key(s.ID)
	}
}
// SendMsgToClient writes message to the client as a single server-side text
// frame. The payload is logged at debug level first. A write failure is
// returned with a stack trace attached.
func (s *Session) SendMsgToClient(message []byte) error {
	logger.Debug().
		Str("type", s.GetType()).
		Str("VIN", s.GetVIN()).
		Int64("SessionID", s.GetUUID()).
		Str("value", string(message)).
		Msg("SendMsgToClient")

	if err := wsutil.WriteServerMessage(s.Websocket, ws.OpText, message); err != nil {
		return errors.WithStack(err)
	}
	return nil
}
// extendDeadline moves the websocket's read and write deadline to `deadline`
// from now. Its signature matches receive's postFrame hook.
func (s *Session) extendDeadline() error {
	until := time.Now().Add(deadline)
	return s.Websocket.SetDeadline(until)
}
// receive reads frames from the websocket until a complete data message is
// available. It returns the message payload together with the opcode of the
// message's first frame.
//
// postFrame, when non-nil, is called after every frame header is read
// (control frames included), e.g. to push the connection deadline forward.
// If it fails, receive returns (nil, 0, err).
//
// Control frames are handed to wsutil.ControlFrameHandler and then skipped.
// If the handler returns an error, receive returns it with the control
// frame's opcode. For a close frame, gobwas/ws presumably replies and returns
// a wsutil.ClosedError (verify against the library version in use). That is
// why callers check for ws.OpClose.
//
// Data messages that the permessage-deflate state marks as compressed are
// inflated before being returned.
//
// NOTE(review): io.ReadAll places no upper bound on message size, nor on the
// inflated size of a compressed message. A client can make the gateway
// buffer arbitrarily large payloads; consider enforcing a limit.
func (s *Session) receive(postFrame func() error) ([]byte, ws.OpCode, error) {
	var (
		err error
		h   ws.Header
		msg wsflate.MessageState
	)
	// Using nil as a source io.Reader since we will Reset() it in the loop
	// below.
	fr := wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor {
		return flate.NewReader(r)
	})
	// Server-side handler for control frames, writing any replies to the same
	// connection.
	controlHandler := wsutil.ControlFrameHandler(s.Websocket, ws.StateServerSide)
	rd := wsutil.Reader{
		Source: s.Websocket,
		// StateExtended: extensions (msg below) are in use, so RSV bits are
		// passed to them instead of being treated as a protocol error.
		State: ws.StateServerSide | ws.StateExtended,
		// Control frames that arrive between fragments of a data message.
		OnIntermediate: controlHandler,
		// msg tracks whether the current message is compressed.
		Extensions: []wsutil.RecvExtension{&msg},
	}
	for {
		h, err = rd.NextFrame()
		if err != nil {
			return nil, h.OpCode, err
		}
		if postFrame != nil {
			err = postFrame()
			if err != nil {
				return nil, 0, err
			}
		}
		// Control frames are consumed here and never returned to the caller,
		// unless the handler reports an error (e.g. on close).
		if h.OpCode.IsControl() {
			if err := controlHandler(h, &rd); err != nil {
				return nil, h.OpCode, err
			}
			continue
		}
		var src io.Reader = &rd
		if msg.IsCompressed() {
			// Re-point the shared decompressor at this message's payload.
			fr.Reset(&rd)
			src = fr
		}
		data, err := io.ReadAll(src)
		if err != nil {
			return nil, h.OpCode, err
		}
		return data, h.OpCode, err
	}
}
// Receive reads the next data message from the client, handling control
// frames along the way. It returns the payload, the message opcode and any
// error (see receive). No per-frame hook is run.
func (s *Session) Receive() ([]byte, ws.OpCode, error) {
	var noHook func() error
	return s.receive(noHook)
}
// Listen reads messages from the websocket and hands each one to Route until
// the connection ends.
//
// Routing failures are logged and do not stop the loop. A message with
// opcode ws.OpClose ends the loop with a nil error; any other receive error
// is logged and returned. ctx is used only to parent the "listen" tracing
// span.
func (s *Session) Listen(ctx context.Context, producer kafka.ProducerInterface) error {
	span, _ := tracer.StartSpanFromContext(ctx, "listen")
	defer span.Finish()

	key := s.Key()
	for {
		msg, op, err := s.Receive()
		switch {
		case op == ws.OpClose:
			logger.At(logger.Info(), "Socket:Listen::EOF closing session ", key).Msg("OpClose")
			return nil
		case err != nil:
			logger.At(logger.Error(), "Socket:Listen::err during receiving session ", key).Err(err).Send()
			return err
		}

		if routeErr := s.Route(producer, msg); routeErr != nil {
			logger.At(logger.Warn(), "Socket:Listen:: failed route session ", key).Err(routeErr).Send()
		}
	}
}
// Route decodes data as a common.EventRawJSON and produces the event's
// Payload to the event's Topic, keyed by the session key.
//
// Embedding types may define their own Route. Go has no virtual dispatch,
// though: Session.Listen always calls Session.Route, so an embedding type
// must also provide its own Listen for an overriding Route to take effect.
func (s *Session) Route(producer kafka.ProducerInterface, data []byte) error {
	var event common.EventRawJSON
	if err := event.Unmarshal(data); err != nil {
		return errors.WithStack(err)
	}
	return producer.Produce(event.Topic, s.Key(), event.Payload, nil)
}
// Load announces a new connection to the rest of the system. It produces an
// "init" GRPC_DepotPayload to kafka.DepotServiceGRPCKafka, keyed by the
// session key.
func (s *Session) Load(producer kafka.ProducerInterface) error {
	key := s.Key()
	logger.At(logger.Info(), "Session::Load connection start notification", key).
		Msgf("session.Load %s", key)

	payload := &kafka_grpc.GRPC_DepotPayload{
		Handler: "init",
	}
	// Never produce an empty payload: surface marshal failures to the caller.
	binaryPayload, err := proto.Marshal(payload)
	if err != nil {
		return errors.Wrap(err, "marshal depot init payload")
	}
	return producer.ProduceBinary(kafka.DepotServiceGRPCKafka, key, binaryPayload, nil)
}
// Teardown notifies the rest of the system that this connection is gone. It
// produces a "del" GRPC_DepotPayload to kafka.DepotServiceGRPCKafka, keyed by
// the session key. It is a no-op after SkipTeardown(true).
func (s *Session) Teardown(producer kafka.ProducerInterface) error {
	// Do not send the del message to the depot service when teardown was
	// explicitly skipped (e.g. when the connection was a duplicate).
	if s.skipteardown {
		return nil
	}

	key := s.Key()
	logger.At(logger.Debug(), "Session::Teardown: Notify services ", key).
		Msgf("session.Teardown %s", key)

	payload := &kafka_grpc.GRPC_DepotPayload{
		Handler: "del",
	}
	// Never produce an empty payload: surface marshal failures to the caller.
	binaryPayload, err := proto.Marshal(payload)
	if err != nil {
		return errors.Wrap(err, "marshal depot del payload")
	}
	return producer.ProduceBinary(kafka.DepotServiceGRPCKafka, key, binaryPayload, nil)
}
// Close closes the underlying websocket connection and returns its error. It
// does not notify other services; Teardown does that.
func (s *Session) Close() error {
	key := s.Key()
	// The event must be dispatched with Send(); without it the log line is
	// built but never emitted.
	logger.At(logger.Debug(), "Session:Close connection for ", key).Send()
	return s.Websocket.Close()
}
// GetWebsocket exposes the session's underlying websocket connection.
func (s *Session) GetWebsocket() net.Conn {
	conn := s.Websocket
	return conn
}
// GetIP returns the remote address of the session's websocket in string form
// (typically "host:port"). It returns "" when the connection reports no
// remote address.
func (s *Session) GetIP() string {
	// net.Conn documents RemoteAddr as returning the address "if known".
	// Guard against a nil Addr instead of panicking on .String().
	addr := s.Websocket.RemoteAddr()
	if addr == nil {
		return ""
	}
	return addr.String()
}
// GetType renders the session's device type via its String method.
func (s *Session) GetType() string {
	name := s.Type.String()
	return name
}
// IsDevice reports whether the session's device type equals device.
func (s *Session) IsDevice(device common.Device) bool {
	return device == s.Type
}
// GetID returns the raw session ID. This is distinct from Key, which may
// decorate the ID with the device type.
func (s *Session) GetID() string {
	id := s.ID
	return id
}
// GetUUID returns the session's epoch value, used as a per-session
// identifier. It is an int64, not an RFC 4122 UUID. NewSession stores
// time.Now().UnixNano() here.
func (s *Session) GetUUID() int64 {
	id := s.epoch
	return id
}
// GetVIN returns the session ID. For vehicle sessions the VIN is stored
// directly as the ID, so no parsing is needed. For other session kinds this
// is simply the ID.
func (s *Session) GetVIN() string {
	return s.ID
}
// SkipTeardown sets whether Teardown should skip notifying the depot service
// for this session (pass true to skip).
func (s *Session) SkipTeardown(skip bool) {
	s.skipteardown = skip
}
// PrintRequest renders r as a human-readable dump for debugging. The dump
// contains, in order:
//   - the request line;
//   - the Host;
//   - every header, with lower-cased names, sorted by name for stable output;
//   - for POST requests only, a single blank line followed by the
//     URL-encoded form.
//
// If the form cannot be parsed, a "(form parse error: ...)" line precedes
// whatever part of the form was parsed.
//
// NOTE: for POST requests this calls r.ParseForm, which consumes r.Body.
// Afterwards the parsed values remain available via r.Form / r.PostForm,
// but the raw body can no longer be read.
func PrintRequest(r *http.Request) string {
	lines := []string{
		fmt.Sprintf("%v %v %v", r.Method, r.URL, r.Proto),
		fmt.Sprintf("Host: %v", r.Host),
	}

	// Map iteration order is random; sort header names so the dump is
	// deterministic and easy to diff.
	names := make([]string, 0, len(r.Header))
	for name := range r.Header {
		names = append(names, name)
	}
	sort.Strings(names)
	for _, name := range names {
		lower := strings.ToLower(name)
		for _, value := range r.Header[name] {
			lines = append(lines, fmt.Sprintf("%v: %v", lower, value))
		}
	}

	if r.Method == http.MethodPost {
		// An empty element yields exactly one blank line between headers and
		// body once joined, mirroring the HTTP wire format.
		lines = append(lines, "")
		if err := r.ParseForm(); err != nil {
			lines = append(lines, fmt.Sprintf("(form parse error: %v)", err))
		}
		lines = append(lines, r.Form.Encode())
	}

	return strings.Join(lines, "\n")
}