Initial cloud-services repo - gateway service + pkg modules
This commit is contained in:
159
pkg/utils/multipart_parser.go
Normal file
159
pkg/utils/multipart_parser.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const chunkSize int = 4096
|
||||
|
||||
// FileInfo used to send part filename and content type
|
||||
type FileInfo struct {
|
||||
FileID string
|
||||
Filename string
|
||||
ContentType string
|
||||
Part *multipart.Part
|
||||
FileSize uint64
|
||||
OrigFileSize uint64
|
||||
}
|
||||
|
||||
// FormParser type func that handles parsing of form fields values
|
||||
type FormParser func(part *multipart.Part) error
|
||||
type FileEncryptor func(input []byte) []byte
|
||||
|
||||
// FindFilePart finds file part of multipart upload
|
||||
func FindFilePart(r *http.Request, boundary string, fileFieldName string, formParser FormParser) (*FileInfo, error) {
|
||||
result := &FileInfo{}
|
||||
mr := multipart.NewReader(r.Body, boundary)
|
||||
|
||||
for {
|
||||
part, err := mr.NextPart()
|
||||
if errors.Is(err, io.EOF) {
|
||||
return result, nil
|
||||
} else if err != nil {
|
||||
return result, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if part.FormName() != fileFieldName || len(part.FileName()) == 0 {
|
||||
if formParser != nil {
|
||||
err = formParser(part)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
}
|
||||
part.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
result.Filename = part.FileName()
|
||||
result.ContentType = GetHTTPHeader(part.Header, "Content-Type", "")
|
||||
len, _ := strconv.ParseInt(GetHTTPHeader(part.Header, "Content-Length", "0"), 10, 32)
|
||||
result.OrigFileSize = uint64(len)
|
||||
result.Part = part
|
||||
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ChunkFilePart chunks file part into pipe writer
|
||||
func ChunkFilePart(writer *io.PipeWriter, info *FileInfo, encryptor FileEncryptor) {
|
||||
var err error
|
||||
var origFileSize uint64
|
||||
var fileSize uint64
|
||||
defer writer.Close()
|
||||
|
||||
n := 0
|
||||
buf := make([]byte, chunkSize)
|
||||
|
||||
for {
|
||||
n, err = io.ReadFull(info.Part, buf)
|
||||
if n > 0 {
|
||||
if encryptor != nil {
|
||||
out := encryptor(buf[:n])
|
||||
writer.Write(out)
|
||||
fileSize += uint64(len(out))
|
||||
} else {
|
||||
writer.Write(buf[:n])
|
||||
fileSize += uint64(n)
|
||||
}
|
||||
origFileSize += uint64(n)
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
if info != nil {
|
||||
info.FileSize = fileSize
|
||||
info.OrigFileSize = origFileSize
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PartReadAll returns entire part value. Maxchars checks that max characters is not exceeded
|
||||
func PartReadAll(part *multipart.Part, maxchars int) (string, error) {
|
||||
var err error
|
||||
n := 0
|
||||
buf := make([]byte, chunkSize)
|
||||
result := []byte{}
|
||||
|
||||
for {
|
||||
n, err = part.Read(buf)
|
||||
result = append(result, buf[:n]...)
|
||||
if len(result) > maxchars {
|
||||
return "", errors.Errorf("%s exceeded %d characters", part.FormName(), maxchars)
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return string(result), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func HashMultipartFile(file multipart.File) ([]byte, error) {
|
||||
var err error
|
||||
n := 0
|
||||
buf := make([]byte, chunkSize)
|
||||
digest := sha256.New()
|
||||
|
||||
for {
|
||||
n, err = file.Read(buf)
|
||||
if n > 0 {
|
||||
_, err := digest.Write(buf[:n])
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
}
|
||||
sum := digest.Sum(nil)
|
||||
return sum, nil
|
||||
}
|
||||
|
||||
func PartReadInt64(part *multipart.Part, maxchars int, bitSize int) (int64, error) {
|
||||
value, err := PartReadAll(part, maxchars)
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
return strconv.ParseInt(value, 10, bitSize)
|
||||
}
|
||||
|
||||
func PartReadUInt64(part *multipart.Part, maxchars int, bitSize int) (uint64, error) {
|
||||
value, err := PartReadAll(part, maxchars)
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
return strconv.ParseUint(value, 10, bitSize)
|
||||
}
|
||||
|
||||
func PartReadInt(part *multipart.Part, maxchars int) (int, error) {
|
||||
value, err := PartReadInt64(part, maxchars, 32)
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
return int(value), nil
|
||||
}
|
||||
Reference in New Issue
Block a user