package queries import ( "github.com/fiskerinc/cloud-services/pkg/common" s "github.com/fiskerinc/cloud-services/pkg/security" "github.com/fiskerinc/cloud-services/pkg/validator" "github.com/go-pg/pg/v10" "github.com/go-pg/pg/v10/orm" "github.com/pkg/errors" ) const sqlSelectMostRecents = "SELECT c1.* FROM certificates c1 INNER JOIN (SELECT common_name, type, MAX(created_at) created_at FROM certificates WHERE common_name = ? AND type IN (?) GROUP BY common_name, type) c2 ON c1.common_name = c2.common_name AND c1.type = c2.type AND c1.created_at = c2.created_at" type CertificatesInterface interface { Insert(c *common.Certificate) (orm.Result, error) Update(c *common.Certificate) (orm.Result, error) Remove(c *common.Certificate) (orm.Result, error) SelectByCommonName(cn string) ([]common.Certificate, error) SelectBySerial(serial string) (*common.Certificate, error) SelectMostRecent(cn string, certType string) (*common.Certificate, error) SelectMostRecents(cn string, certTypes []string) ([]common.Certificate, error) } type Certificates struct { QueryBase } func (c *Certificates) Insert(certificate *common.Certificate) (orm.Result, error) { enc := s.Encrypt{} encryptor, err := enc.GetEncryptor() if err != nil { return nil, err } certificate.EncryptedKey = encryptor.EncryptChunk([]byte(certificate.EncryptedKey)) return c.resultWithStack(c.GetDBConn().Model(certificate).Insert()) } func (c *Certificates) Update(certificate *common.Certificate) (orm.Result, error) { if certificate.SerialNumber == "" { return nil, errors.WithStack(&validator.FieldError{ ErrorMsg: "Serial required", }) } return c.resultWithStack(c.GetDBConn().Model(certificate).Column("valid").WherePK().Update()) } func (c *Certificates) Remove(certificate *common.Certificate) (orm.Result, error) { if certificate.SerialNumber == "" { return nil, &validator.FieldError{ ErrorMsg: "Serial required", } } return c.resultWithStack(c.GetDBConn().Model(certificate).WherePK().Delete()) } func (c *Certificates) SelectByCommonName(cn string) ([]common.Certificate, error) { certificates := []common.Certificate{} err := c.GetDBConn().Model(&certificates).Where("common_name = ?", cn).Select() if err != nil { return nil, errors.WithStack(err) } for i := range certificates { err = c.decrypt(&certificates[i]) if err != nil { return nil, err } certificates[i].PrivateKey = string(certificates[i].EncryptedKey) } return certificates, err } func (c *Certificates) SelectBySerial(serial string) (*common.Certificate, error) { certificate := common.Certificate{} err := c.GetDBConn().Model(&certificate).Where("serial_number = ?", serial).Select() if err != nil { return nil, errors.WithStack(err) } err = c.decrypt(&certificate) certificate.PrivateKey = string(certificate.EncryptedKey) return &certificate, err } func (c *Certificates) SelectMostRecent(cn string, certType string) (*common.Certificate, error) { cert := common.Certificate{} err := c.GetDBConn().Model(&cert).Where("common_name = ? AND type = ?", cn, certType).Order("created_at desc").Limit(1).Select() if err != nil { return nil, errors.WithStack(err) } err = c.decrypt(&cert) cert.PrivateKey = string(cert.EncryptedKey) return &cert, err } func (c *Certificates) SelectMostRecents(cn string, certTypes []string) ([]common.Certificate, error) { certificates := []common.Certificate{} _, err := c.GetDBConn().Model().Query(&certificates, sqlSelectMostRecents, cn, pg.In(certTypes)) if err != nil { return nil, errors.WithStack(err) } for i := range certificates { err = c.decrypt(&certificates[i]) if err != nil { return nil, err } certificates[i].PrivateKey = string(certificates[i].EncryptedKey) } return certificates, err } func (c *Certificates) decrypt(cert *common.Certificate) error { if cert.EncryptedKey == nil { return nil } enc := s.Encrypt{} encryptor, err := enc.GetEncryptor() if err != nil { return err } pkey, err := encryptor.DecryptChunk([]byte(cert.EncryptedKey)) if err != nil { return err } cert.EncryptedKey = pkey return nil }