diff --git a/.gitignore b/.gitignore index 9375eb5..c1488bc 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,6 @@ .vscode .idea vendor + +# Temp directory +tmp/ diff --git a/examples/mean-reversion/mean-reversion.go b/examples/mean-reversion/mean-reversion.go index 88294f1..98706d1 100644 --- a/examples/mean-reversion/mean-reversion.go +++ b/examples/mean-reversion/mean-reversion.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "os" + "sync/atomic" "time" movingaverage "github.com/RobinUS2/golang-moving-average" @@ -19,7 +20,7 @@ const ( windowSize = 20 ) -type alpacaClientContainer struct { +type algo struct { tradeClient *alpaca.Client dataClient *marketdata.Client streamClient *stream.StocksClient @@ -27,11 +28,10 @@ type alpacaClientContainer struct { movingAverage *movingaverage.MovingAverage lastOrder string stock string + shouldTrade atomic.Bool } -var algo alpacaClientContainer - -func init() { +func main() { // You can set your API key/secret here or you can use environment variables! apiKey := "" apiSecret := "" @@ -40,12 +40,13 @@ func init() { // Change feed to sip if you have proper subscription feed := "iex" - // Check if user input a stock, default is AAPL - stock := "AAPL" - if len(os.Args[1:]) == 1 { - stock = os.Args[1] + symbol := "AAPL" + if len(os.Args) > 1 { + symbol = os.Args[1] } - algo = alpacaClientContainer{ + fmt.Println("Selected symbol: " + symbol) + + a := &algo{ tradeClient: alpaca.NewClient(alpaca.ClientOpts{ APIKey: apiKey, APISecret: apiSecret, @@ -60,22 +61,23 @@ func init() { ), feed: feed, movingAverage: movingaverage.New(windowSize), - stock: stock, + stock: symbol, } -} -func main() { fmt.Println("Cancelling all open orders so they don't impact our buying power...") - orders, err := algo.tradeClient.GetOrders(alpaca.GetOrdersRequest{ + orders, err := a.tradeClient.GetOrders(alpaca.GetOrdersRequest{ Status: "open", Until: time.Now(), Limit: 100, }) + for _, order := range orders { + fmt.Printf("%+v\n", order) + } if err != nil { log.Fatalf("Failed to list orders: %v", err) } for _, order := range orders { - if err := algo.tradeClient.CancelOrder(order.ID); err != nil { + if err := a.tradeClient.CancelOrder(order.ID); err != nil { log.Fatalf("Failed to cancel orders: %v", err) } } @@ -84,21 +86,21 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := algo.streamClient.Connect(ctx); err != nil { + if err := a.streamClient.Connect(ctx); err != nil { log.Fatalf("Failed to connect to the marketdata stream: %v", err) } - if err := algo.streamClient.SubscribeToBars(algo.onBar, algo.stock); err != nil { + if err := a.streamClient.SubscribeToBars(a.onBar, a.stock); err != nil { log.Fatalf("Failed to subscribe to the bars stream: %v", err) } go func() { - if err := <-algo.streamClient.Terminated(); err != nil { + if err := <-a.streamClient.Terminated(); err != nil { log.Fatalf("The marketdata stream was terminated: %v", err) } }() for { - isOpen, err := algo.awaitMarketOpen() + isOpen, err := a.awaitMarketOpen() if err != nil { log.Fatalf("Failed to wait for market open: %v", err) } @@ -106,27 +108,28 @@ func main() { time.Sleep(1 * time.Minute) continue } - fmt.Printf("The market is open! Waiting for %s minute bars...\n", algo.stock) + fmt.Printf("The market is open! Waiting for %s minute bars...\n", a.stock) // Reset the moving average for the day - algo.movingAverage = movingaverage.New(windowSize) + a.movingAverage = movingaverage.New(windowSize) - bars, err := algo.dataClient.GetBars(algo.stock, marketdata.GetBarsRequest{ + bars, err := a.dataClient.GetBars(a.stock, marketdata.GetBarsRequest{ TimeFrame: marketdata.OneMin, Start: time.Now().Add(-1 * (windowSize + 1) * time.Minute), End: time.Now(), - Feed: algo.feed, + Feed: a.feed, }) if err != nil { log.Fatalf("Failed to get historical bar: %v", err) } for _, bar := range bars { - algo.movingAverage.Add(bar.Close) + a.movingAverage.Add(bar.Close) } + a.shouldTrade.Store(true) // During market open we react on the minute bars (onBar) - clock, err := algo.tradeClient.GetClock() + clock, err := a.tradeClient.GetClock() if err != nil { log.Fatalf("Failed to get clock: %v", err) } @@ -134,44 +137,40 @@ func main() { time.Sleep(untilClose) fmt.Println("Market closing soon. Closing position.") - if _, err := algo.tradeClient.ClosePosition(algo.stock, alpaca.ClosePositionRequest{}); err != nil { - log.Fatalf("Failed to close position: %v", algo.stock) + a.shouldTrade.Store(false) + if _, err := a.tradeClient.ClosePosition(a.stock, alpaca.ClosePositionRequest{}); err != nil { + log.Fatalf("Failed to close position: %v", a.stock) } fmt.Println("Position closed.") } } -func (alp alpacaClientContainer) onBar(bar stream.Bar) { - clock, err := algo.tradeClient.GetClock() - if err != nil { - fmt.Println("Failed to get clock:", err) - return - } - if !clock.IsOpen { +func (a *algo) onBar(bar stream.Bar) { + if !a.shouldTrade.Load() { return } - if algo.lastOrder != "" { - _ = alp.tradeClient.CancelOrder(algo.lastOrder) + if a.lastOrder != "" { + _ = a.tradeClient.CancelOrder(a.lastOrder) } - algo.movingAverage.Add(bar.Close) - count := algo.movingAverage.Count() + a.movingAverage.Add(bar.Close) + count := a.movingAverage.Count() if count < windowSize { fmt.Printf("Waiting for %d bars, now we have %d", windowSize, count) return } - avg := algo.movingAverage.Avg() + avg := a.movingAverage.Avg() fmt.Printf("Latest minute bar close price: %g, latest %d average: %g\n", bar.Close, windowSize, avg) - if err := algo.rebalance(bar.Close, avg); err != nil { + if err := a.rebalance(bar.Close, avg); err != nil { fmt.Println("Failed to rebalance:", err) } } // Spin until the market is open. -func (alp alpacaClientContainer) awaitMarketOpen() (bool, error) { - clock, err := algo.tradeClient.GetClock() +func (a *algo) awaitMarketOpen() (bool, error) { + clock, err := a.tradeClient.GetClock() if err != nil { return false, fmt.Errorf("get clock: %w", err) } @@ -184,11 +183,11 @@ func (alp alpacaClientContainer) awaitMarketOpen() (bool, error) { } // Rebalance our position after an update. -func (alp alpacaClientContainer) rebalance(currPrice, avg float64) error { +func (a *algo) rebalance(currPrice, avg float64) error { // Get our position, if any. positionQty := 0 positionVal := 0.0 - position, err := alp.tradeClient.GetPosition(algo.stock) + position, err := a.tradeClient.GetPosition(a.stock) if err != nil { if apiErr, ok := err.(*alpaca.APIError); !ok || apiErr.Message != "position does not exist" { return fmt.Errorf("get position: %w", err) @@ -202,7 +201,7 @@ func (alp alpacaClientContainer) rebalance(currPrice, avg float64) error { // Sell our position if the price is above the running average, if any. if positionQty > 0 { fmt.Println("Setting long position to zero") - if err := alp.submitLimitOrder(positionQty, algo.stock, currPrice, "sell"); err != nil { + if err := a.submitLimitOrder(positionQty, a.stock, currPrice, "sell"); err != nil { return fmt.Errorf("submit limit order: %v", err) } } else { @@ -210,12 +209,12 @@ func (alp alpacaClientContainer) rebalance(currPrice, avg float64) error { } } else if currPrice < avg { // Determine optimal amount of shares based on portfolio and market data. - account, err := alp.tradeClient.GetAccount() + account, err := a.tradeClient.GetAccount() if err != nil { return fmt.Errorf("get account: %w", err) } buyingPower, _ := account.BuyingPower.Float64() - positions, err := alp.tradeClient.GetPositions() + positions, err := a.tradeClient.GetPositions() if err != nil { return fmt.Errorf("list positions: %w", err) } @@ -234,7 +233,7 @@ func (alp alpacaClientContainer) rebalance(currPrice, avg float64) error { amountToAdd = buyingPower } qtyToBuy := int(amountToAdd / currPrice) - if err := alp.submitLimitOrder(qtyToBuy, algo.stock, currPrice, "buy"); err != nil { + if err := a.submitLimitOrder(qtyToBuy, a.stock, currPrice, "buy"); err != nil { return fmt.Errorf("submit limit order: %v", err) } } else { @@ -243,7 +242,7 @@ func (alp alpacaClientContainer) rebalance(currPrice, avg float64) error { if qtyToSell > positionQty { qtyToSell = positionQty } - if err := alp.submitLimitOrder(qtyToSell, algo.stock, currPrice, "sell"); err != nil { + if err := a.submitLimitOrder(qtyToSell, a.stock, currPrice, "sell"); err != nil { return fmt.Errorf("submit limit order: %v", err) } } @@ -252,13 +251,13 @@ func (alp alpacaClientContainer) rebalance(currPrice, avg float64) error { } // Submit a limit order if quantity is above 0. -func (alp alpacaClientContainer) submitLimitOrder(qty int, symbol string, price float64, side string) error { +func (a *algo) submitLimitOrder(qty int, symbol string, price float64, side string) error { if qty <= 0 { fmt.Printf("Quantity is <= 0, order of | %d %s %s | not sent.\n", qty, symbol, side) } adjSide := alpaca.Side(side) decimalQty := decimal.NewFromInt(int64(qty)) - order, err := alp.tradeClient.PlaceOrder(alpaca.PlaceOrderRequest{ + order, err := a.tradeClient.PlaceOrder(alpaca.PlaceOrderRequest{ Symbol: symbol, Qty: &decimalQty, Side: adjSide, @@ -270,6 +269,6 @@ func (alp alpacaClientContainer) submitLimitOrder(qty int, symbol string, price return fmt.Errorf("qty=%d symbol=%s side=%s: %w", qty, symbol, side, err) } fmt.Printf("Limit order of | %d %s %s | sent.\n", qty, symbol, side) - algo.lastOrder = order.ID + a.lastOrder = order.ID return nil }