better separation of concerns

This commit is contained in:
Jesse Duffield 2020-06-14 13:44:11 +10:00
parent 62f26a105e
commit 17a9719f3b
6 changed files with 208 additions and 91 deletions

20
main.go
View file

@ -21,9 +21,27 @@ func main() {
} else { } else {
dir = os.Args[2] dir = os.Args[2]
} }
if err := commands.Bind(dir); err != nil { paths, err := commands.GetHorcruxPathsInDir(dir)
if err != nil {
log.Fatal(err) log.Fatal(err)
} }
overwrite := false
for {
if err := commands.Bind(paths, "", overwrite); err != nil {
if err != os.ErrExist {
log.Fatal(err)
}
overwriteResponse := commands.Prompt("A file already exists at destination. Overwrite? (Y/N):")
if overwriteResponse == "Y" || overwriteResponse == "y" || overwriteResponse == "yes" {
overwrite = true
} else {
log.Fatal("You have chosen not to overwrite the file. Cancelling.")
}
} else {
break
}
}
return return
} }

View file

@ -1,71 +1,125 @@
package commands package commands
import ( import (
"bufio"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"sort"
"strings"
"github.com/jesseduffield/horcrux/pkg/multiplexing" "github.com/jesseduffield/horcrux/pkg/multiplexing"
"github.com/jesseduffield/horcrux/pkg/shamir" "github.com/jesseduffield/horcrux/pkg/shamir"
) )
func Bind(dir string) error { func GetHorcruxPathsInDir(dir string) ([]string, error) {
files, err := ioutil.ReadDir(dir) files, err := ioutil.ReadDir(dir)
if err != nil { if err != nil {
return err return nil, err
} }
filenames := []string{} paths := []string{}
for _, file := range files { for _, file := range files {
if filepath.Ext(file.Name()) == ".horcrux" { if filepath.Ext(file.Name()) == ".horcrux" {
filenames = append(filenames, file.Name()) paths = append(paths, file.Name())
} }
} }
headers := []horcruxHeader{} return paths, nil
horcruxFiles := []*os.File{} }
for _, filename := range filenames { type byIndex []Horcrux
file, err := os.Open(filename)
defer file.Close() func (h byIndex) Len() int {
return len(h)
}
func (h byIndex) Swap(i, j int) {
h[i], h[j] = h[j], h[i]
}
func (h byIndex) Less(i, j int) bool {
return h[i].GetHeader().Index < h[j].GetHeader().Index
}
func GetHorcruxes(paths []string) ([]Horcrux, error) {
horcruxes := []Horcrux{}
for _, path := range paths {
currentHorcrux, err := NewHorcrux(path)
if err != nil { if err != nil {
return err return nil, err
} }
for _, horcrux := range horcruxes {
currentHeader, err := getHeaderFromHorcruxFile(file) if horcrux.GetHeader().Index == currentHorcrux.GetHeader().Index && horcrux.GetHeader().OriginalFilename == currentHorcrux.GetHeader().OriginalFilename {
if err != nil {
return err
}
for _, header := range headers {
if header.Index == currentHeader.Index {
// we've already obtained this horcrux so we'll skip this instance // we've already obtained this horcrux so we'll skip this instance
continue continue
} }
} }
if len(headers) > 0 && (currentHeader.OriginalFilename != headers[0].OriginalFilename || currentHeader.Timestamp != headers[0].Timestamp) { horcruxes = append(horcruxes, *currentHorcrux)
}
sort.Sort(byIndex(horcruxes))
return horcruxes, nil
}
func ValidateHorcruxes(horcruxes []Horcrux) error {
if len(horcruxes) == 0 {
return errors.New("No horcruxes supplied")
}
if len(horcruxes) < horcruxes[0].GetHeader().Threshold {
return fmt.Errorf(
"You do not have all the required horcruxes. There are %d required to resurrect the original file. You only have %d",
horcruxes[0].GetHeader().Threshold,
len(horcruxes),
)
}
for _, horcrux := range horcruxes {
if !strings.HasSuffix(horcrux.GetPath(), ".horcrux") {
return fmt.Errorf("%s is not a horcrux file (requires .horcrux extension)", horcrux.GetPath())
}
if horcrux.GetHeader().OriginalFilename != horcruxes[0].GetHeader().OriginalFilename || horcrux.GetHeader().Timestamp != horcruxes[0].GetHeader().Timestamp {
return errors.New("All horcruxes in the given directory must have the same original filename and timestamp.") return errors.New("All horcruxes in the given directory must have the same original filename and timestamp.")
} }
headers = append(headers, *currentHeader)
horcruxFiles = append(horcruxFiles, file)
} }
if len(headers) == 0 { return nil
return errors.New("No horcruxes in directory") }
} else if len(headers) < headers[0].Threshold {
return errors.New(fmt.Sprintf("You do not have all the required horcruxes. There are %d required to resurrect the original file. You only have %d", headers[0].Threshold, len(headers))) func Bind(paths []string, dstPath string, overwrite bool) error {
horcruxes, err := GetHorcruxes(paths)
if err != nil {
return err
} }
keyFragments := make([][]byte, len(headers)) if err := ValidateHorcruxes(horcruxes); err != nil {
return err
}
firstHorcrux := horcruxes[0]
// if dstPath is empty we use the original filename
if dstPath == "" {
cwd, err := os.Getwd()
if err != nil {
return err
}
dstPath = filepath.Join(cwd, firstHorcrux.GetHeader().OriginalFilename)
}
if fileExists(dstPath) && !overwrite {
return os.ErrExist
}
keyFragments := make([][]byte, len(horcruxes))
for i := range keyFragments { for i := range keyFragments {
keyFragments[i] = headers[i].KeyFragment keyFragments[i] = horcruxes[i].GetHeader().KeyFragment
} }
key, err := shamir.Combine(keyFragments) key, err := shamir.Combine(keyFragments)
@ -74,28 +128,22 @@ func Bind(dir string) error {
} }
var fileReader io.Reader var fileReader io.Reader
if headers[0].Total == headers[0].Threshold { if firstHorcrux.GetHeader().Total == firstHorcrux.GetHeader().Threshold {
// sort by index horcruxFiles := make([]*os.File, len(horcruxes))
orderedHorcruxFiles := make([]*os.File, len(horcruxFiles)) for i, horcrux := range horcruxes {
for i, h := range horcruxFiles { horcruxFiles[i] = horcrux.GetFile()
orderedHorcruxFiles[headers[i].Index-1] = h
} }
fileReader = &multiplexing.Multiplexer{Readers: orderedHorcruxFiles} fileReader = &multiplexing.Multiplexer{Readers: horcruxFiles}
} else { } else {
fileReader = horcruxFiles[0] // arbitrarily read from the first horcrux: they all contain the same contents fileReader = firstHorcrux.GetFile() // arbitrarily read from the first horcrux: they all contain the same contents
} }
reader := cryptoReader(fileReader, key) reader := cryptoReader(fileReader, key)
newFilename := headers[0].OriginalFilename _ = os.Truncate(dstPath, 0)
if fileExists(newFilename) {
newFilename = prompt("A file already exists named '%s'. Enter new file name: ", newFilename)
}
_ = os.Truncate(newFilename, 0) newFile, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE, 0644)
newFile, err := os.OpenFile(newFilename, os.O_WRONLY|os.O_CREATE, 0644)
if err != nil { if err != nil {
return err return err
} }
@ -108,34 +156,3 @@ func Bind(dir string) error {
return err return err
} }
// this function gets the header from the horcrux file and ensures that we leave
// the file with its read pointer at the start of the encrypted content
// so that we can later directly read from that point
// yes this is a side effect, no I'm not proud of it.
func getHeaderFromHorcruxFile(file *os.File) (*horcruxHeader, error) {
currentHeader := &horcruxHeader{}
scanner := bufio.NewScanner(file)
bytesBeforeBody := 0
for scanner.Scan() {
line := scanner.Text()
bytesBeforeBody += len(scanner.Bytes()) + 1
if line == "-- HEADER --" {
scanner.Scan()
bytesBeforeBody += len(scanner.Bytes()) + 1
headerLine := scanner.Bytes()
json.Unmarshal(headerLine, currentHeader)
scanner.Scan() // one more to get past the body line
bytesBeforeBody += len(scanner.Bytes()) + 1
break
}
}
if _, err := file.Seek(int64(bytesBeforeBody), io.SeekStart); err != nil {
return nil, err
}
if currentHeader == nil {
return nil, errors.New("could not find header in horcrux file")
}
return currentHeader, nil
}

90
pkg/commands/horcrux.go Normal file
View file

@ -0,0 +1,90 @@
package commands
import (
"bufio"
"encoding/json"
"errors"
"io"
"os"
)
type HorcruxHeader struct {
OriginalFilename string `json:"originalFilename"`
Timestamp int64 `json:"timestamp"`
Index int `json:"index"`
Total int `json:"total"`
Threshold int `json:"threshold"`
KeyFragment []byte `json:"keyFragment"`
}
type Horcrux struct {
path string
header HorcruxHeader
file *os.File
}
// returns a horcrux with its header parsed, and it's file's read pointer
// right after the header.
func NewHorcrux(path string) (*Horcrux, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
header, err := GetHeaderFromHorcruxFile(file)
if err != nil {
return nil, err
}
return &Horcrux{
path: path,
file: file,
header: *header,
}, nil
}
// this function gets the header from the horcrux file and ensures that we leave
// the file with its read pointer at the start of the encrypted content
// so that we can later directly read from that point
// yes this is a side effect, no I'm not proud of it.
func GetHeaderFromHorcruxFile(file *os.File) (*HorcruxHeader, error) {
currentHeader := &HorcruxHeader{}
scanner := bufio.NewScanner(file)
bytesBeforeBody := 0
for scanner.Scan() {
line := scanner.Text()
bytesBeforeBody += len(scanner.Bytes()) + 1
if line == "-- HEADER --" {
scanner.Scan()
bytesBeforeBody += len(scanner.Bytes()) + 1
headerLine := scanner.Bytes()
if err := json.Unmarshal(headerLine, currentHeader); err != nil {
return nil, err
}
scanner.Scan() // one more to get past the body line
bytesBeforeBody += len(scanner.Bytes()) + 1
break
}
}
if _, err := file.Seek(int64(bytesBeforeBody), io.SeekStart); err != nil {
return nil, err
}
if currentHeader == nil {
return nil, errors.New("could not find header in horcrux file")
}
return currentHeader, nil
}
func (h *Horcrux) GetHeader() HorcruxHeader {
return h.header
}
func (h *Horcrux) GetPath() string {
return h.path
}
func (h *Horcrux) GetFile() *os.File {
return h.file
}

View file

@ -1,10 +0,0 @@
package commands
type horcruxHeader struct {
OriginalFilename string `json:"originalFilename"`
Timestamp int64 `json:"timestamp"`
Index int `json:"index"`
Total int `json:"total"`
Threshold int `json:"threshold"`
KeyFragment []byte `json:"keyFragment"`
}

View file

@ -64,7 +64,7 @@ func Split(path string, destination string, total int, threshold int) error {
for i := range horcruxFiles { for i := range horcruxFiles {
index := i + 1 index := i + 1
headerBytes, err := json.Marshal(&horcruxHeader{ headerBytes, err := json.Marshal(&HorcruxHeader{
OriginalFilename: originalFilename, OriginalFilename: originalFilename,
Timestamp: timestamp, Timestamp: timestamp,
Index: index, Index: index,
@ -90,7 +90,9 @@ func Split(path string, destination string, total int, threshold int) error {
} }
defer horcruxFile.Close() defer horcruxFile.Close()
horcruxFile.WriteString(header(index, total, headerBytes)) if _, err := horcruxFile.WriteString(header(index, total, headerBytes)); err != nil {
return err
}
horcruxFiles[i] = horcruxFile horcruxFiles[i] = horcruxFile
} }
@ -133,7 +135,7 @@ func obtainTotalAndThreshold() (int, int, error) {
threshold := *thresholdPtr threshold := *thresholdPtr
if total == 0 { if total == 0 {
totalStr := prompt("How many horcruxes do you want to split this file into? (2-99): ") totalStr := Prompt("How many horcruxes do you want to split this file into? (2-99): ")
var err error var err error
total, err = strconv.Atoi(totalStr) total, err = strconv.Atoi(totalStr)
if err != nil { if err != nil {
@ -142,7 +144,7 @@ func obtainTotalAndThreshold() (int, int, error) {
} }
if threshold == 0 { if threshold == 0 {
thresholdStr := prompt("How many horcruxes should be required to reconstitute the original file? If you require all horcruxes, the resulting files will take up less space, but it will feel less magical (2-99): ") thresholdStr := Prompt("How many horcruxes should be required to reconstitute the original file? If you require all horcruxes, the resulting files will take up less space, but it will feel less magical (2-99): ")
var err error var err error
threshold, err = strconv.Atoi(thresholdStr) threshold, err = strconv.Atoi(thresholdStr)
if err != nil { if err != nil {

View file

@ -30,7 +30,7 @@ func fileExists(filename string) bool {
return !info.IsDir() return !info.IsDir()
} }
func prompt(message string, args ...interface{}) string { func Prompt(message string, args ...interface{}) string {
reader := bufio.NewReader(os.Stdin) reader := bufio.NewReader(os.Stdin)
fmt.Printf(message, args...) fmt.Printf(message, args...)
input, _ := reader.ReadString('\n') input, _ := reader.ReadString('\n')