using System;
using System.Linq;
using System.Collections.Generic;
using Ecng.Common;
using Ecng.Collections;
using Ecng.Serialization;
using StockSharp.Algo.Indicators;
using StockSharp.Algo.Strategies;
using StockSharp.BusinessEntities;
using StockSharp.Messages;
namespace StockSharp.Samples.Strategies;
/// <summary>
/// Port of the MetaTrader expert TestMnistOnnx.
/// Converts a rolling grid of candle closes into pattern classes and trades on a selected class.
/// The original mouse-drawn image is replaced with market data derived features.
/// </summary>
/// <remarks>
/// A position is opened only when the classified pattern equals <see cref="TargetClass"/>,
/// its confidence is at least <see cref="ConfidenceThreshold"/>, its bias is not neutral and
/// the strategy is flat. Every entry is followed by a cooldown of finished candles during which
/// no classification happens. Exits are left to the percent take-profit / stop-loss protection
/// started in <c>OnStarted2</c>.
/// </remarks>
public class MnistPatternClassifierStrategy : Strategy
{
	/// <summary>
	/// Number of finished candles skipped after every entry before classification resumes.
	/// </summary>
	private const int CooldownBars = 50;

	/// <summary>
	/// Base fraction from which the trend, breakout, flat and momentum thresholds are derived.
	/// </summary>
	private const decimal BaseThreshold = 0.001m;

	private readonly StrategyParam<int> _lookbackPeriod;
	private readonly StrategyParam<int> _targetClass;
	private readonly StrategyParam<decimal> _confidenceThreshold;
	private readonly StrategyParam<DataType> _candleType;

	// Created in OnStarted2; kept as fields so ProcessCandle can check whether they are formed.
	private RelativeStrengthIndex _rsi = null!;
	private AverageTrueRange _atr = null!;

	// Last LookbackPeriod finished closes, oldest first.
	private readonly Queue<decimal> _closeWindow = new();
	// Oldest close currently in the window.
	private decimal _firstClose;
	// Close of the previously processed finished candle.
	private decimal _previousClose;
	// Result of the latest classification; written here but not read by any other code in this class.
	private int _lastClass = -1;
	private decimal _lastConfidence;
	// Remaining finished candles to skip after an entry.
	private int _cooldown;

	/// <summary>
	/// Directional bias attached to a classified pattern.
	/// </summary>
	private enum PatternBiases
	{
		Neutral,
		Bullish,
		Bearish,
	}

	/// <summary>
	/// Number of finished candles that form the MNIST-like grid.
	/// </summary>
	public int LookbackPeriod
	{
		get => _lookbackPeriod.Value;
		set => _lookbackPeriod.Value = value;
	}

	/// <summary>
	/// Pattern class (0-9) that will trigger trading actions.
	/// Only patterns of exactly this class (with a non-neutral bias) open positions.
	/// </summary>
	public int TargetClass
	{
		get => _targetClass.Value;
		set => _targetClass.Value = value;
	}

	/// <summary>
	/// Minimum confidence required before orders are sent.
	/// </summary>
	public decimal ConfidenceThreshold
	{
		get => _confidenceThreshold.Value;
		set => _confidenceThreshold.Value = value;
	}

	/// <summary>
	/// Candle type that feeds the pattern grid.
	/// </summary>
	public DataType CandleType
	{
		get => _candleType.Value;
		set => _candleType.Value = value;
	}

	/// <summary>
	/// Initializes a new instance of the <see cref="MnistPatternClassifierStrategy"/> class.
	/// </summary>
	public MnistPatternClassifierStrategy()
	{
		_lookbackPeriod = Param(nameof(LookbackPeriod), 14)
			.SetRange(10, 200)
			.SetDisplay("Lookback", "Number of candles converted into the pattern grid", "Pattern");

		_targetClass = Param(nameof(TargetClass), 1)
			.SetRange(0, 9)
			.SetDisplay("Target Class", "Pattern class that should be traded", "Pattern");

		_confidenceThreshold = Param(nameof(ConfidenceThreshold), 0.2m)
			.SetRange(0m, 1m)
			.SetDisplay("Confidence", "Minimum classification confidence", "Pattern");

		_candleType = Param(nameof(CandleType), TimeSpan.FromMinutes(5).TimeFrame())
			.SetDisplay("Candle Type", "Primary timeframe used for the pattern", "General");
	}

	/// <inheritdoc />
	public override IEnumerable<(Security sec, DataType dt)> GetWorkingSecurities()
	{
		yield return (Security, CandleType);
	}

	/// <inheritdoc />
	protected override void OnReseted()
	{
		base.OnReseted();

		_closeWindow.Clear();
		_firstClose = 0m;
		_previousClose = 0m;
		_lastClass = -1;
		_lastConfidence = 0m;
		_cooldown = 0;
	}

	/// <inheritdoc />
	protected override void OnStarted2(DateTime time)
	{
		base.OnStarted2(time);

		// Positions are closed only by this percent-based protection.
		StartProtection(
			takeProfit: new Unit(3, UnitTypes.Percent),
			stopLoss: new Unit(2, UnitTypes.Percent));

		_rsi = new RelativeStrengthIndex
		{
			Length = LookbackPeriod,
		};

		_atr = new AverageTrueRange
		{
			Length = LookbackPeriod,
		};

		var subscription = SubscribeCandles(CandleType);
		subscription
			.Bind(_rsi, _atr, ProcessCandle)
			.Start();
	}

	/// <summary>
	/// Handles a candle together with the bound RSI and ATR values.
	/// </summary>
	/// <param name="candle">Incoming candle; unfinished candles are ignored.</param>
	/// <param name="rsiValue">RSI value for this candle.</param>
	/// <param name="atrValue">ATR value for this candle.</param>
	private void ProcessCandle(ICandleMessage candle, decimal rsiValue, decimal atrValue)
	{
		if (candle.State != CandleStates.Finished)
			return;

		var close = candle.ClosePrice;
		UpdateWindow(close);

		// Classify only once the grid is full and both indicators produce formed values.
		if (_closeWindow.Count < LookbackPeriod || !_rsi.IsFormed || !_atr.IsFormed)
		{
			_previousClose = close;
			return;
		}

		if (_cooldown > 0)
		{
			_cooldown--;
			_previousClose = close;
			return;
		}

		var pattern = ClassifyPattern(close, rsiValue, atrValue);
		_lastClass = pattern.PatternClass;
		_lastConfidence = pattern.Confidence;

		// Only the configured target class is traded.
		var isTradeSignal = pattern.PatternClass == TargetClass
			&& pattern.Confidence >= ConfidenceThreshold
			&& pattern.Bias != PatternBiases.Neutral;

		if (isTradeSignal)
			ExecuteBias(pattern.Bias);

		_previousClose = close;
	}

	/// <summary>
	/// Opens a market position in the direction of <paramref name="bias"/> when the strategy is flat
	/// and starts the post-entry cooldown. Neutral bias or an open position does nothing.
	/// </summary>
	/// <param name="bias">Direction of the pattern.</param>
	private void ExecuteBias(PatternBiases bias)
	{
		if (Position != 0)
			return;

		switch (bias)
		{
			case PatternBiases.Bullish:
				BuyMarket();
				break;
			case PatternBiases.Bearish:
				SellMarket();
				break;
			default:
				return;
		}

		_cooldown = CooldownBars;
	}

	/// <summary>
	/// Appends <paramref name="close"/> to the rolling window, trims it to <see cref="LookbackPeriod"/>
	/// entries and refreshes the cached oldest close.
	/// </summary>
	/// <param name="close">Close price of the latest finished candle.</param>
	private void UpdateWindow(decimal close)
	{
		_closeWindow.Enqueue(close);

		// 'while' keeps the window correct even if LookbackPeriod is lowered at runtime.
		while (_closeWindow.Count > LookbackPeriod && _closeWindow.Count > 0)
			_closeWindow.Dequeue();

		_firstClose = _closeWindow.Count > 0 ? _closeWindow.Peek() : 0m;
	}

	/// <summary>
	/// Maps the current window features to a pattern class (0-9), a confidence and a bias.
	/// </summary>
	/// <param name="currentClose">Close of the candle being classified.</param>
	/// <param name="rsiValue">Current RSI value.</param>
	/// <param name="atrValue">Current ATR value.</param>
	/// <returns>The classified pattern.</returns>
	private PatternResult ClassifyPattern(decimal currentClose, decimal rsiValue, decimal atrValue)
	{
		var stats = CalculateStatistics(currentClose, rsiValue, atrValue);

		var trendStrength = stats.TrendStrength;
		var rangeStrength = stats.RangeStrength;
		var breakoutRange = stats.BreakoutThreshold;
		var rangePosition = stats.RangePosition;
		var momentum = stats.Momentum;
		var rsi = stats.Rsi;
		var atr = stats.AtrNormalized;

		// Blend the features into a score capped at 1, similar to the ONNX output probability.
		// The trend term uses its magnitude so bullish and bearish trends contribute symmetrically.
		var momentumScore = Math.Min(1m, Math.Abs(momentum) / stats.MomentumThreshold);
		var confidence = Math.Min(1m, (Math.Abs(trendStrength) + rangeStrength + momentumScore + stats.RsiDeviation + atr) / 5m);

		// Almost no price movement: flat pattern.
		if (rangeStrength < stats.FlatThreshold)
			return new PatternResult(0, Math.Max(confidence, 0.4m), PatternBiases.Neutral);

		// Uptrend: breakout near the top (3), pullback (6) or plain trend (1).
		if (trendStrength >= stats.TrendThreshold)
		{
			if (rangePosition >= 0.75m && rangeStrength >= breakoutRange)
				return new PatternResult(3, confidence, PatternBiases.Bullish);

			if (momentum < 0m)
				return new PatternResult(6, confidence * 0.8m, PatternBiases.Bullish);

			return new PatternResult(1, confidence, PatternBiases.Bullish);
		}

		// Downtrend: breakdown near the bottom (4), bounce (7) or plain trend (2).
		if (trendStrength <= -stats.TrendThreshold)
		{
			if (rangePosition <= 0.25m && rangeStrength >= breakoutRange)
				return new PatternResult(4, confidence, PatternBiases.Bearish);

			if (momentum > 0m)
				return new PatternResult(7, confidence * 0.8m, PatternBiases.Bearish);

			return new PatternResult(2, confidence, PatternBiases.Bearish);
		}

		// No clear trend but a wide range.
		if (rangeStrength >= breakoutRange)
			return new PatternResult(5, confidence * 0.9m, PatternBiases.Neutral);

		// Mean-reversion setups combining range position with RSI.
		if (rangePosition <= 0.4m && rsi >= 55m)
			return new PatternResult(8, confidence * 0.85m, PatternBiases.Bullish);

		if (rangePosition >= 0.6m && rsi <= 45m)
			return new PatternResult(9, confidence * 0.85m, PatternBiases.Bearish);

		return new PatternResult(0, confidence * 0.7m, PatternBiases.Neutral);
	}

	/// <summary>
	/// Computes normalized features of the close window plus the fixed classification thresholds.
	/// </summary>
	/// <param name="currentClose">Close of the candle being classified.</param>
	/// <param name="rsiValue">Current RSI value.</param>
	/// <param name="atrValue">Current ATR value.</param>
	/// <returns>Feature and threshold bundle.</returns>
	private PatternStatistics CalculateStatistics(decimal currentClose, decimal rsiValue, decimal atrValue)
	{
		// Fall back to the current close so an empty window cannot produce an overflowing range.
		var min = _closeWindow.Count > 0 ? _closeWindow.Min() : currentClose;
		var max = _closeWindow.Count > 0 ? _closeWindow.Max() : currentClose;

		var first = _firstClose;
		var last = currentClose;
		var range = max - min;

		return new PatternStatistics
		{
			TrendStrength = first != 0m ? (last - first) / first : 0m,
			RangeStrength = first != 0m ? range / first : 0m,
			BreakoutThreshold = BaseThreshold * 1.4m,
			FlatThreshold = BaseThreshold * 0.3m,
			RangePosition = range > 0m ? (last - min) / range : 0.5m,
			Momentum = _previousClose != 0m ? (last - _previousClose) / _previousClose : 0m,
			Rsi = rsiValue,
			AtrNormalized = first != 0m ? Math.Min(1m, atrValue / first) : 0m,
			TrendThreshold = BaseThreshold,
			MomentumThreshold = BaseThreshold,
			RsiDeviation = Math.Min(1m, Math.Abs(rsiValue - 50m) / 50m),
		};
	}

	/// <summary>
	/// Classification output: pattern class (0-9), confidence and trade bias.
	/// </summary>
	private readonly record struct PatternResult(int PatternClass, decimal Confidence, PatternBiases Bias);

	/// <summary>
	/// Normalized window features and the thresholds they are compared against.
	/// </summary>
	private readonly struct PatternStatistics
	{
		/// <summary>Relative change from the oldest close to the current close.</summary>
		public decimal TrendStrength { get; init; }
		/// <summary>Window high-low range relative to the oldest close.</summary>
		public decimal RangeStrength { get; init; }
		/// <summary>Range strength at or above which a move counts as a breakout.</summary>
		public decimal BreakoutThreshold { get; init; }
		/// <summary>Range strength below which the window counts as flat.</summary>
		public decimal FlatThreshold { get; init; }
		/// <summary>Position of the current close inside the window range (0 = low, 1 = high).</summary>
		public decimal RangePosition { get; init; }
		/// <summary>Relative change from the previous close to the current close.</summary>
		public decimal Momentum { get; init; }
		/// <summary>Raw RSI value.</summary>
		public decimal Rsi { get; init; }
		/// <summary>ATR relative to the oldest close, capped at 1.</summary>
		public decimal AtrNormalized { get; init; }
		/// <summary>Absolute trend strength required for a trend pattern.</summary>
		public decimal TrendThreshold { get; init; }
		/// <summary>Momentum magnitude that maps to a full momentum score.</summary>
		public decimal MomentumThreshold { get; init; }
		/// <summary>Distance of RSI from 50, scaled to [0, 1].</summary>
		public decimal RsiDeviation { get; init; }
	}
}
import clr
from collections import deque
clr.AddReference("StockSharp.Messages")
clr.AddReference("StockSharp.Algo")
clr.AddReference("StockSharp.Algo.Indicators")
clr.AddReference("StockSharp.Algo.Strategies")
from System import TimeSpan
from StockSharp.Messages import DataType, CandleStates, Unit, UnitTypes
from StockSharp.Algo.Indicators import RelativeStrengthIndex, AverageTrueRange
from StockSharp.Algo.Strategies import Strategy
class mnist_pattern_classifier_strategy(Strategy):
    """Python port of the MNIST-style pattern classifier strategy (MetaTrader TestMnistOnnx).

    Keeps a rolling window of the last ``LookbackPeriod`` finished closes,
    derives trend / range / momentum / RSI / ATR features from it, maps them to a
    pattern class (0-9) with a confidence and a bias, and opens a market position
    when the class equals ``TargetClass``, the confidence reaches
    ``ConfidenceThreshold``, the bias is not neutral and the strategy is flat.
    Exits are left to the percent take-profit / stop-loss protection.
    """

    # Bias codes returned as the third element of _classify_pattern's tuple.
    _BIAS_NEUTRAL = 0
    _BIAS_BULLISH = 1
    _BIAS_BEARISH = 2

    # Finished candles skipped after every entry before classification resumes.
    _COOLDOWN_BARS = 50

    # Base fraction the trend / breakout / flat / momentum thresholds derive from.
    _BASE_THRESHOLD = 0.001

    def __init__(self):
        super(mnist_pattern_classifier_strategy, self).__init__()
        self._candle_type = self.Param("CandleType", DataType.TimeFrame(TimeSpan.FromMinutes(5)))
        self._lookback_period = self.Param("LookbackPeriod", 14)
        self._target_class = self.Param("TargetClass", 1)
        self._confidence_threshold = self.Param("ConfidenceThreshold", 0.2)
        # Indicators are created in OnStarted2; kept so _process_candle can check IsFormed.
        self._rsi = None
        self._atr = None
        self._reset_state()

    @property
    def CandleType(self):
        """Candle type that feeds the pattern grid."""
        return self._candle_type.Value

    @CandleType.setter
    def CandleType(self, value):
        self._candle_type.Value = value

    @property
    def LookbackPeriod(self):
        """Number of finished candles that form the pattern grid."""
        return self._lookback_period.Value

    @LookbackPeriod.setter
    def LookbackPeriod(self, value):
        self._lookback_period.Value = value

    @property
    def TargetClass(self):
        """Pattern class (0-9) that triggers trading actions."""
        return self._target_class.Value

    @TargetClass.setter
    def TargetClass(self, value):
        self._target_class.Value = value

    @property
    def ConfidenceThreshold(self):
        """Minimum confidence required before orders are sent."""
        return self._confidence_threshold.Value

    @ConfidenceThreshold.setter
    def ConfidenceThreshold(self, value):
        self._confidence_threshold.Value = value

    def _reset_state(self):
        """Clear the rolling window and all per-run classification state."""
        self._close_window = deque()
        self._first_close = 0.0
        self._previous_close = 0.0
        # Latest classification; written by _process_candle but not read elsewhere here.
        self._last_class = -1
        self._last_confidence = 0.0
        self._cooldown = 0

    def OnReseted(self):
        super(mnist_pattern_classifier_strategy, self).OnReseted()
        self._reset_state()

    def OnStarted2(self, time):
        super(mnist_pattern_classifier_strategy, self).OnStarted2(time)
        self._reset_state()

        # Positions are closed only by this percent-based protection.
        self.StartProtection(
            takeProfit=Unit(3, UnitTypes.Percent),
            stopLoss=Unit(2, UnitTypes.Percent))

        self._rsi = RelativeStrengthIndex()
        self._rsi.Length = self.LookbackPeriod
        self._atr = AverageTrueRange()
        self._atr.Length = self.LookbackPeriod

        subscription = self.SubscribeCandles(self.CandleType)
        subscription.Bind(self._rsi, self._atr, self._process_candle).Start()

    def _update_window(self, close):
        """Append ``close``, trim the window to ``LookbackPeriod`` and cache its oldest value."""
        lookback = self.LookbackPeriod
        self._close_window.append(close)
        while len(self._close_window) > lookback and self._close_window:
            self._close_window.popleft()
        self._first_close = self._close_window[0] if self._close_window else 0.0

    def _classify_pattern(self, current_close, rsi_val, atr_val):
        """Map the current window features to ``(pattern_class, confidence, bias)``.

        ``bias`` is one of the ``_BIAS_*`` class constants.
        """
        window = self._close_window
        min_val = min(window) if window else 0.0
        max_val = max(window) if window else 0.0
        first = self._first_close
        last = current_close
        r = max_val - min_val

        range_strength = r / first if first != 0 else 0.0
        trend = (last - first) / first if first != 0 else 0.0
        momentum = (last - self._previous_close) / self._previous_close if self._previous_close != 0 else 0.0
        rsi_deviation = min(1.0, abs(rsi_val - 50.0) / 50.0)
        atr_normalized = min(1.0, atr_val / first) if first != 0 else 0.0
        range_position = (last - min_val) / r if r > 0 else 0.5

        base_threshold = self._BASE_THRESHOLD
        trend_threshold = base_threshold
        breakout_threshold = base_threshold * 1.4
        flat_threshold = base_threshold * 0.3
        momentum_threshold = base_threshold

        # Blend the features into a score capped at 1; the trend term uses its magnitude
        # so bullish and bearish trends contribute symmetrically.
        momentum_score = min(1.0, abs(momentum) / momentum_threshold)
        confidence = min(1.0, (abs(trend) + range_strength + momentum_score + rsi_deviation + atr_normalized) / 5.0)

        neutral, bullish, bearish = self._BIAS_NEUTRAL, self._BIAS_BULLISH, self._BIAS_BEARISH

        # Almost no price movement: flat pattern.
        if range_strength < flat_threshold:
            return (0, max(confidence, 0.4), neutral)

        # Uptrend: breakout near the top (3), pullback (6) or plain trend (1).
        if trend >= trend_threshold:
            if range_position >= 0.75 and range_strength >= breakout_threshold:
                return (3, confidence, bullish)
            if momentum < 0:
                return (6, confidence * 0.8, bullish)
            return (1, confidence, bullish)

        # Downtrend: breakdown near the bottom (4), bounce (7) or plain trend (2).
        if trend <= -trend_threshold:
            if range_position <= 0.25 and range_strength >= breakout_threshold:
                return (4, confidence, bearish)
            if momentum > 0:
                return (7, confidence * 0.8, bearish)
            return (2, confidence, bearish)

        # No clear trend but a wide range.
        if range_strength >= breakout_threshold:
            return (5, confidence * 0.9, neutral)

        # Mean-reversion setups combining range position with RSI.
        if range_position <= 0.4 and rsi_val >= 55.0:
            return (8, confidence * 0.85, bullish)
        if range_position >= 0.6 and rsi_val <= 45.0:
            return (9, confidence * 0.85, bearish)

        return (0, confidence * 0.7, neutral)

    def _process_candle(self, candle, rsi_value, atr_value):
        """Handle a candle with its bound RSI and ATR values; unfinished candles are ignored."""
        if candle.State != CandleStates.Finished:
            return

        close = float(candle.ClosePrice)
        rsi_val = float(rsi_value)
        atr_val = float(atr_value)

        self._update_window(close)

        # Classify only once the grid is full and both indicators produce formed values.
        if (len(self._close_window) < self.LookbackPeriod
                or not self._rsi.IsFormed or not self._atr.IsFormed):
            self._previous_close = close
            return

        if self._cooldown > 0:
            self._cooldown -= 1
            self._previous_close = close
            return

        pattern_class, confidence, bias = self._classify_pattern(close, rsi_val, atr_val)
        self._last_class = pattern_class
        self._last_confidence = confidence

        # Only the configured target class is traded.
        is_signal = (pattern_class == int(self.TargetClass)
                     and confidence >= float(self.ConfidenceThreshold)
                     and bias != self._BIAS_NEUTRAL
                     and self.Position == 0)

        if is_signal:
            if bias == self._BIAS_BULLISH:
                self.BuyMarket()
            else:
                self.SellMarket()
            self._cooldown = self._COOLDOWN_BARS

        self._previous_close = close

    def CreateClone(self):
        return mnist_pattern_classifier_strategy()