Files
cloud-services/pkg/security/encryptedstream.go

99 lines
2.4 KiB
Go

package security
import (
"fmt"
"math"
"strconv"
"github.com/pkg/errors"
)
// IEncryptedStream turns plaintext into a framed, encrypted byte stream and
// turns such a stream back into plaintext.
type IEncryptedStream interface {
	// Write encrypts rawData and returns the framed ciphertext bytes.
	Write(rawData []byte) []byte
	// Read decrypts framed ciphertext and returns the recovered plaintext,
	// or an error if the input is malformed or fails to decrypt.
	Read(rawData []byte) ([]byte, error)
}
const (
	// blockSize is the maximum number of plaintext bytes Write encrypts per chunk.
	blockSize = 4096
	// headerSize is the length in bytes of the stream header (the unique ID
	// accepted by WithUniqueId and consumed by the first Read).
	headerSize = 16
)
// EncryptedStream implements IEncryptedStream (see NewEncryptedStream). It
// splits data into chunks, encrypts each one with an IEncryptor, and frames
// every ciphertext chunk with a 4-character hex length prefix. An optional
// header precedes the first chunk.
type EncryptedStream struct {
	// gcm_encrypter encrypts and decrypts individual chunks. The name suggests
	// AES-GCM; NOTE(review): that is an assumption about the IEncryptor in use.
	gcm_encrypter IEncryptor
	// header is written before the first chunk by Write, or captured from the
	// first headerSize bytes of input by Read. It is nil unless set.
	header []byte
	// header_applied records whether the header has already been written (by
	// Write) or consumed (by Read) on this stream.
	header_applied bool
}
// StreamOption configures an EncryptedStream during NewEncryptedStream. It
// returns the modified stream, or an error that aborts construction.
type StreamOption func(EncryptedStream) (EncryptedStream, error)
// NewEncryptedStream returns a stream that uses encryptor for every chunk.
// Each option is applied in order to the stream under construction. The first
// option that fails aborts construction, and its error is returned unchanged
// together with a nil stream. Without options the stream has no header.
func NewEncryptedStream(encryptor IEncryptor, options ...StreamOption) (IEncryptedStream, error) {
	s := EncryptedStream{gcm_encrypter: encryptor}
	for _, apply := range options {
		configured, err := apply(s)
		if err != nil {
			return nil, err
		}
		s = configured
	}
	return &s, nil
}
// WithUniqueId returns a StreamOption that sets the stream header to
// uniqueID, which must be exactly headerSize (16) bytes long; otherwise
// applying the option returns an error and leaves the stream unchanged.
// The bytes are copied when the option is applied, so later changes to the
// caller's slice do not alter the header. Applying the option also clears
// header_applied, so the next Write emits the header.
func WithUniqueId(uniqueID []byte) StreamOption {
	return func(stream EncryptedStream) (EncryptedStream, error) {
		if len(uniqueID) != headerSize {
			return stream, errors.New("invalid file id - must be 16 bytes")
		}
		// Copy so the stream's header does not alias the caller's buffer.
		stream.header = append([]byte(nil), uniqueID...)
		stream.header_applied = false
		return stream, nil
	}
}
func (s *EncryptedStream) Write(rawData []byte) []byte {
var length = len(rawData)
highWatermark := 0
byteStream := make([]byte, 0)
index := 0
if !s.header_applied {
byteStream = append(byteStream, s.header...)
s.header_applied = true
}
for index < length {
highWatermark = int(math.Min(float64(length-index), float64(blockSize)))
slice := s.gcm_encrypter.EncryptChunk(rawData[index:highWatermark])
chunk_size := fmt.Sprintf("%04x", len(slice))
byteStream = append(byteStream, chunk_size...)
byteStream = append(byteStream, slice...)
index += highWatermark
}
return byteStream
}
// Read decodes data framed by Write and returns the concatenated plaintext.
//
// On the first call on a stream, the leading headerSize bytes are consumed as
// the header (a copy is stored in s.header). After that, Read repeatedly
// parses a 4-character hex length prefix and decrypts that many following
// bytes with the stream's encryptor.
//
// rawData must contain whole frames. An error is returned if the header is
// incomplete, a length prefix is truncated or not plain hex, a chunk runs past
// the end of rawData, or decryption fails.
func (s *EncryptedStream) Read(rawData []byte) ([]byte, error) {
	// prefixLen is the width of the hex length prefix written by Write ("%04x").
	const prefixLen = 4

	length := len(rawData)
	if !s.header_applied && length < headerSize {
		return nil, errors.New("invalid stream")
	}
	index := 0
	if !s.header_applied {
		// Copy so the stored header does not alias the caller's buffer.
		s.header = append([]byte(nil), rawData[:headerSize]...)
		s.header_applied = true
		index = headerSize
	}
	byteStream := make([]byte, 0)
	for index < length {
		if length-index < prefixLen {
			return nil, errors.New("invalid stream: truncated chunk length")
		}
		// Parse strictly as base 16 so only plain hex digits are accepted
		// (base 0 with a "0x" prefix would also accept '_' separators).
		chunkLen, err := strconv.ParseUint(string(rawData[index:index+prefixLen]), 16, 16)
		if err != nil {
			return nil, err
		}
		index += prefixLen
		end := index + int(chunkLen)
		if end > length {
			return nil, errors.New("invalid stream: truncated chunk")
		}
		plain, err := s.gcm_encrypter.DecryptChunk(rawData[index:end])
		if err != nil {
			return nil, err
		}
		byteStream = append(byteStream, plain...)
		index = end
	}
	return byteStream, nil
}