
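Mnist Pattern Classifier Strategy

Source

This strategy is adapted from the MetaTrader 5 expert advisor TestMnistOnnx.mq5 (MQL ID 47225). The original script offers an interactive canvas where the user draws a digit by hand. A built-in MNIST ONNX model then recognizes the digit. The StockSharp version keeps the pattern-recognition idea but replaces the hand-drawn canvas with a rolling window of finished candles.

Core idea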

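  1. Keep a rolling window of the closes of the last LookbackPeriod (default 14) finished candles. The window plays the role of the MNIST image, but no actual 28×28 pixel grid is built.
  2. Compute several statistics over the window and average them into a "confidence" score that stands in for the neural network's output probability. The statistics are range width, trend strength, one-bar momentum, RSI deviation from 50, and ATR normalized by the first close in the window.
  3. Map these features to ten pattern classes, 0–9. Each class represents a market state, such as a narrow range, a trend, a breakout, a pullback or a reversal.
  4. The strategy opens a position when three conditions hold: the detected class equals the user-selected TargetClass, the class has a direction, and its confidence is at least ConfidenceThreshold. It enters only while it is flat. After an entry it ignores the next 50 candles. Open positions are closed by the take-profit/stop-loss protection, not by a change of class.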
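Steps 1 and 2 rest on a few simple ratios. The snippet below is a hypothetical standalone helper that mirrors the arithmetic in CalculateStatistics outside StockSharp. The names WindowFeatureSketch and Compute are illustrative and do not come from the strategy or the library.

Trend and range width are measured relative to the first close in the window. The range position places the latest close between the window low (0) and the window high (1). Momentum is the one-bar return. The sketch assumes the window is ordered oldest first and already includes the close of the candle that just finished.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper: mirrors the ratios computed in CalculateStatistics, outside StockSharp.
public static class WindowFeatureSketch
{
    // closes: oldest first; the last element is the close of the candle that just finished.
    public static (decimal Trend, decimal RangeStrength, decimal RangePosition, decimal Momentum)
        Compute(IReadOnlyList<decimal> closes)
    {
        if (closes.Count < 2)
            throw new ArgumentException("Need at least two closes.", nameof(closes));

        decimal first = closes[0];
        decimal last = closes[closes.Count - 1];
        decimal previous = closes[closes.Count - 2]; // plays the role of _previousClose
        decimal min = closes.Min();
        decimal max = closes.Max();
        decimal range = max - min;

        decimal trend = first != 0m ? (last - first) / first : 0m;             // net move vs first close
        decimal rangeStrength = first != 0m ? range / first : 0m;              // high-low width vs first close
        decimal rangePosition = range > 0m ? (last - min) / range : 0.5m;      // 0 = window low, 1 = window high
        decimal momentum = previous != 0m ? (last - previous) / previous : 0m; // one-bar return

        return (trend, rangeStrength, rangePosition, momentum);
    }
}
```

For closes 100, 101, 99, 102 the helper returns a trend of 0.02 and a range strength of 0.03. The range position is 1, because the close sits at the window high. The momentum is 3/99, about 0.0303.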

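Parameters

Parameter | Default | Description
LookbackPeriod | 14 | Number of finished candles in the rolling window. Also used as the RSI and ATR length (allowed range 10–200).
TargetClass | 1 | Pattern class (0–9) that triggers trading.
ConfidenceThreshold | 0.2 | Minimum confidence (0–1) required before an order is sent.
Volume | 1 | Order volume for new positions (inherited from Strategy).
CandleType | 5-minute candles | Candle data type used for updates.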

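Pattern classes

Class | Meaning | Direction traded
0 | Low volatility or narrow range. | None
1 | Sustained bullish trend. | Long
2 | Sustained bearish trend. | Short
3 | Upside breakout with strong follow-through. | Long
4 | Downside breakout with strong follow-through. | Short
5 | Wide, high-volatility range with no clear direction. | None
6 | Pullback within an uptrend. | Long
7 | Pullback within a downtrend. | Short
8 | Bullish reversal after a prolonged decline. | Long
9 | Bearish reversal after a prolonged advance. | Short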
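In the strategy code, ClassifyPattern attaches a direction to every class, and only directional classes can open a trade. The hypothetical helper below restates that mapping. The names Bias, PatternClassMap, BiasOf and CanTrade are illustrative only. One consequence of the mapping is that setting TargetClass to 0 or 5 never produces an order.

```csharp
using System;

// Hypothetical names; the mapping follows the PatternResult values returned by ClassifyPattern.
public enum Bias { Neutral, Bullish, Bearish }

public static class PatternClassMap
{
    public static Bias BiasOf(int patternClass) => patternClass switch
    {
        1 or 3 or 6 or 8 => Bias.Bullish, // uptrend, upside breakout, pullback in uptrend, bullish reversal
        2 or 4 or 7 or 9 => Bias.Bearish, // mirror images of the bullish classes
        0 or 5 => Bias.Neutral,           // narrow range, wide directionless range
        _ => throw new ArgumentOutOfRangeException(nameof(patternClass), "Class must be 0-9."),
    };

    // Only classes with a direction can ever open a position.
    public static bool CanTrade(int patternClass) => BiasOf(patternClass) != Bias.Neutral;
}
```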

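Trading rules

  • Calculations run only on finished candles. This matches the original expert, which runs inference only after the drawing is complete.
  • Entries use market orders (BuyMarket / SellMarket) and are placed only when the strategy is flat. At most one position is open at a time, and the strategy never reverses an existing position.
  • Confidence is capped at 1. Raising ConfidenceThreshold filters out weaker signals.
  • Exits are handled by protection: OnStarted2 calls StartProtection with a 3% take-profit and a 2% stop-loss. The classifier itself never sends closing orders.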
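The entry rules condense into a single pure function. This is a hypothetical sketch, not StockSharp API. It encodes direction as +1 (long), -1 (short) and 0 (neutral) for brevity, and the 50-candle cooldown is the value hard-coded in ExecuteBias. The function returns the order direction for the current candle together with the cooldown to carry into the next one.

```csharp
using System;

// Hypothetical pure-function restatement of the entry rules; EntryGate and Step are illustrative names.
public static class EntryGate
{
    public const int CooldownBars = 50; // value hard-coded in ExecuteBias

    // bias: +1 bullish, -1 bearish, 0 neutral (encoding chosen for this sketch).
    // Returns the order direction (+1 buy, -1 sell, 0 none) and the cooldown for the next candle.
    public static (int Direction, int Cooldown) Step(
        int patternClass, decimal confidence, int bias,
        int targetClass, decimal threshold, decimal position, int cooldown)
    {
        if (cooldown > 0)
            return (0, cooldown - 1); // skip classification while cooling down

        bool enter = patternClass == targetClass // trade the selected class only
            && confidence >= threshold           // filter weak signals
            && bias != 0                         // neutral classes never trade
            && position == 0m;                   // single position: enter only when flat

        return enter ? (Math.Sign(bias), CooldownBars) : (0, 0);
    }
}
```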

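Usage tips

  • Choose a candle timeframe that reflects the rhythm of the target market. Shorter timeframes react faster but are noisier.
  • Tune TargetClass and ConfidenceThreshold together. Some patterns occur rarely, so a lower threshold may suit them. Classes 0 and 5 have no direction and never open a position.
  • The classifier is fully deterministic and does not depend on an external ONNX runtime library.
  • The take-profit and stop-loss percentages are hard-coded in OnStarted2. Adjust them, or replace StartProtection with other StockSharp risk tools, to suit the instrument.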

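Differences from the original

  • Interactive drawing is removed. Historical candles are analyzed automatically instead.
  • The "confidence" comes from a combination of indicators rather than neural-network probabilities.
  • Automatic trading logic based on the recognition result has been added.
  • MNIST resource files are no longer required.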
using System;
using System.Linq;
using System.Collections.Generic;

using Ecng.Common;
using Ecng.Collections;
using Ecng.Serialization;

using StockSharp.Algo.Indicators;
using StockSharp.Algo.Strategies;
using StockSharp.BusinessEntities;
using StockSharp.Messages;

namespace StockSharp.Samples.Strategies;

/// <summary>
/// Port of the MetaTrader expert TestMnistOnnx.
/// Converts a rolling window of candle closes into pattern classes and trades on a selected class.
/// The original mouse-drawn image is replaced with market data derived features.
/// </summary>
public class MnistPatternClassifierStrategy : Strategy
{
	private readonly StrategyParam<int> _lookbackPeriod;
	private readonly StrategyParam<int> _targetClass;
	private readonly StrategyParam<decimal> _confidenceThreshold;
	private readonly StrategyParam<DataType> _candleType;

	private RelativeStrengthIndex _rsi = null!;
	private AverageTrueRange _atr = null!;

	private readonly Queue<decimal> _closeWindow = new();
	private decimal _firstClose;
	private decimal _previousClose;

	private int _lastClass = -1;
	private decimal _lastConfidence;
	private int _cooldown;

	private enum PatternBiases
	{
		Neutral,
		Bullish,
		Bearish,
	}

	/// <summary>
	/// Number of finished candles that form the MNIST-like grid.
	/// </summary>
	public int LookbackPeriod
	{
		get => _lookbackPeriod.Value;
		set => _lookbackPeriod.Value = value;
	}

	/// <summary>
	/// Pattern class (0-9) that will trigger trading actions.
	/// </summary>
	public int TargetClass
	{
		get => _targetClass.Value;
		set => _targetClass.Value = value;
	}

	/// <summary>
	/// Minimum confidence required before orders are sent.
	/// </summary>
	public decimal ConfidenceThreshold
	{
		get => _confidenceThreshold.Value;
		set => _confidenceThreshold.Value = value;
	}


	/// <summary>
	/// Candle type that feeds the pattern grid.
	/// </summary>
	public DataType CandleType
	{
		get => _candleType.Value;
		set => _candleType.Value = value;
	}

	/// <summary>
	/// Initializes a new instance of the <see cref="MnistPatternClassifierStrategy"/> class.
	/// </summary>
	public MnistPatternClassifierStrategy()
	{
		_lookbackPeriod = Param(nameof(LookbackPeriod), 14)
			.SetRange(10, 200)
			.SetDisplay("Lookback", "Number of candles converted into the pattern grid", "Pattern");

		_targetClass = Param(nameof(TargetClass), 1)
			.SetRange(0, 9)
			.SetDisplay("Target Class", "Pattern class that should be traded", "Pattern");

		_confidenceThreshold = Param(nameof(ConfidenceThreshold), 0.2m)
			.SetRange(0m, 1m)
			.SetDisplay("Confidence", "Minimum classification confidence", "Pattern");

		_candleType = Param(nameof(CandleType), TimeSpan.FromMinutes(5).TimeFrame())
			.SetDisplay("Candle Type", "Primary timeframe used for the pattern", "General");
	}

	/// <inheritdoc />
	public override IEnumerable<(Security sec, DataType dt)> GetWorkingSecurities()
	{
		yield return (Security, CandleType);
	}

	/// <inheritdoc />
	protected override void OnReseted()
	{
		base.OnReseted();

		_closeWindow.Clear();
		_firstClose = 0m;
		_previousClose = 0m;
		_lastClass = -1;
		_lastConfidence = 0m;
		_cooldown = 0;
	}

	/// <inheritdoc />
	protected override void OnStarted2(DateTime time)
	{
		base.OnStarted2(time);

		StartProtection(
			takeProfit: new Unit(3, UnitTypes.Percent),
			stopLoss: new Unit(2, UnitTypes.Percent));

		_rsi = new RelativeStrengthIndex
		{
			Length = LookbackPeriod,
		};

		_atr = new AverageTrueRange
		{
			Length = LookbackPeriod,
		};

		var subscription = SubscribeCandles(CandleType);
		subscription
		.Bind(_rsi, _atr, ProcessCandle)
		.Start();
	}

	private void ProcessCandle(ICandleMessage candle, decimal rsiValue, decimal atrValue)
	{
		if (candle.State != CandleStates.Finished)
			return;

		UpdateWindow(candle.ClosePrice);

		if (_closeWindow.Count < LookbackPeriod)
		{
			_previousClose = candle.ClosePrice;
			return;
		}

		if (_cooldown > 0)
		{
			_cooldown--;
			_previousClose = candle.ClosePrice;
			return;
		}

		var pattern = ClassifyPattern(candle.ClosePrice, rsiValue, atrValue);

		_lastClass = pattern.PatternClass;
		_lastConfidence = pattern.Confidence;

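		// Trade only the selected class, only with enough confidence, only with a direction, and only when flat.
		if (pattern.PatternClass == TargetClass
			&& pattern.Confidence >= ConfidenceThreshold
			&& pattern.Bias != PatternBiases.Neutral
			&& Position == 0)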
		{
			ExecuteBias(pattern.Bias);
		}

		_previousClose = candle.ClosePrice;
	}

	private void ExecuteBias(PatternBiases bias)
	{
		switch (bias)
		{
		case PatternBiases.Bullish:
			if (Position == 0)
			{
				BuyMarket();
				_cooldown = 50;
			}

			break;
		case PatternBiases.Bearish:
			if (Position == 0)
			{
				SellMarket();
				_cooldown = 50;
			}

			break;
		default:
			break;
		}
	}

	private void UpdateWindow(decimal close)
	{
		_closeWindow.Enqueue(close);

		if (_closeWindow.Count > LookbackPeriod)
		{
			_closeWindow.Dequeue();
		}

		_firstClose = _closeWindow.Count > 0 ? _closeWindow.Peek() : 0m;
	}

	private PatternResult ClassifyPattern(decimal currentClose, decimal rsiValue, decimal atrValue)
	{
		var stats = CalculateStatistics(currentClose, rsiValue, atrValue);

		var trendStrength = stats.TrendStrength;
		var rangeStrength = stats.RangeStrength;
		var breakoutRange = stats.BreakoutThreshold;
		var rangePosition = stats.RangePosition;
		var momentum = stats.Momentum;
		var rsi = stats.Rsi;
		var atr = stats.AtrNormalized;

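		// Compute a blended confidence score similar to the ONNX output probability.
		// The trend term is taken as an absolute value so bearish patterns score like bullish ones,
		// and the result is clamped to [0, 1].
		var momentumScore = Math.Min(1m, Math.Abs(momentum) / stats.MomentumThreshold);
		var confidence = Math.Max(0m, Math.Min(1m, (Math.Abs(trendStrength) + rangeStrength + momentumScore + stats.RsiDeviation + atr) / 5m));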

		if (rangeStrength < stats.FlatThreshold)
		{
			return new PatternResult(0, Math.Max(confidence, 0.4m), PatternBiases.Neutral);
		}

		if (trendStrength >= stats.TrendThreshold)
		{
			if (rangePosition >= 0.75m && rangeStrength >= breakoutRange)
			{
				return new PatternResult(3, confidence, PatternBiases.Bullish);
			}

			if (momentum < 0m)
			{
				return new PatternResult(6, confidence * 0.8m, PatternBiases.Bullish);
			}

			return new PatternResult(1, confidence, PatternBiases.Bullish);
		}

		if (trendStrength <= -stats.TrendThreshold)
		{
			if (rangePosition <= 0.25m && rangeStrength >= breakoutRange)
			{
				return new PatternResult(4, confidence, PatternBiases.Bearish);
			}

			if (momentum > 0m)
			{
				return new PatternResult(7, confidence * 0.8m, PatternBiases.Bearish);
			}

			return new PatternResult(2, confidence, PatternBiases.Bearish);
		}

		if (rangeStrength >= breakoutRange)
		{
			return new PatternResult(5, confidence * 0.9m, PatternBiases.Neutral);
		}

		if (rangePosition <= 0.4m && rsi >= 55m)
		{
			return new PatternResult(8, confidence * 0.85m, PatternBiases.Bullish);
		}

		if (rangePosition >= 0.6m && rsi <= 45m)
		{
			return new PatternResult(9, confidence * 0.85m, PatternBiases.Bearish);
		}

		return new PatternResult(0, confidence * 0.7m, PatternBiases.Neutral);
	}

	private PatternStatistics CalculateStatistics(decimal currentClose, decimal rsiValue, decimal atrValue)
	{
		var window = _closeWindow.ToArray();
		decimal min = decimal.MaxValue;
		decimal max = decimal.MinValue;

		foreach (var value in window)
		{
			if (value < min)
				min = value;

			if (value > max)
				max = value;
		}

		var first = _firstClose;
		var last = currentClose;
		var range = max - min;
		var rangeStrength = first != 0m ? range / first : 0m;
		var trend = first != 0m ? (last - first) / first : 0m;
		var momentum = _previousClose != 0m ? (last - _previousClose) / _previousClose : 0m;
		var rsiDeviation = Math.Min(1m, Math.Abs(rsiValue - 50m) / 50m);
		var atrNormalized = first != 0m ? Math.Min(1m, atrValue / first) : 0m;

		var rangePosition = range > 0m ? (last - min) / range : 0.5m;

		const decimal baseThreshold = 0.001m;
		var trendThreshold = baseThreshold;
		var breakoutThreshold = baseThreshold * 1.4m;
		var flatThreshold = baseThreshold * 0.3m;
		var momentumThreshold = baseThreshold;

		return new PatternStatistics
		{
			TrendStrength = trend,
			RangeStrength = rangeStrength,
			BreakoutThreshold = breakoutThreshold,
			FlatThreshold = flatThreshold,
			RangePosition = rangePosition,
			Momentum = momentum,
			Rsi = rsiValue,
			AtrNormalized = atrNormalized,
			TrendThreshold = trendThreshold,
			MomentumThreshold = momentumThreshold,
			RsiDeviation = rsiDeviation,
		};
	}

	private readonly struct PatternResult
	{
		public PatternResult(int patternClass, decimal confidence, PatternBiases bias)
		{
			PatternClass = patternClass;
			Confidence = confidence;
			Bias = bias;
		}

		public int PatternClass { get; }

		public decimal Confidence { get; }

		public PatternBiases Bias { get; }
	}

	private readonly struct PatternStatistics
	{
		public decimal TrendStrength { get; init; }
		public decimal RangeStrength { get; init; }
		public decimal BreakoutThreshold { get; init; }
		public decimal FlatThreshold { get; init; }
		public decimal RangePosition { get; init; }
		public decimal Momentum { get; init; }
		public decimal Rsi { get; init; }
		public decimal AtrNormalized { get; init; }
		public decimal TrendThreshold { get; init; }
		public decimal MomentumThreshold { get; init; }
		public decimal RsiDeviation { get; init; }
	}
}
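The confidence blend in ClassifyPattern averages five terms. The standalone sketch below reproduces that average and clamps it to [0, 1]. ConfidenceSketch and Blend are hypothetical names. Taking the absolute value of the trend term is an assumption made here, so that a downtrend scores the same as an equally strong uptrend.

```csharp
using System;

// Hypothetical standalone version of the confidence blend in ClassifyPattern.
public static class ConfidenceSketch
{
    public static decimal Blend(decimal trend, decimal rangeStrength, decimal momentum,
                                decimal momentumThreshold, decimal rsi, decimal atrNormalized)
    {
        decimal momentumTerm = Math.Min(1m, Math.Abs(momentum) / momentumThreshold); // saturates at 1
        decimal rsiDeviation = Math.Min(1m, Math.Abs(rsi - 50m) / 50m);               // 0 at RSI 50, 1 at 0 or 100
        // Assumption: |trend|, so a downtrend scores like an equally strong uptrend.
        decimal raw = (Math.Abs(trend) + rangeStrength + momentumTerm + rsiDeviation + atrNormalized) / 5m;
        return Math.Clamp(raw, 0m, 1m);
    }
}
```

Consider these inputs: a trend of 0.02, a range strength of 0.03, a one-bar move of 0.002 against the 0.001 threshold, an RSI of 70, and a normalized ATR of 0.05. The five terms are then 0.02, 0.03, 1, 0.4 and 0.05, which gives a confidence of 0.3.

The momentum term saturates at 1 for any one-bar move of 0.1% or more. As a result, it and the RSI term tend to dominate the score, while the price-ratio terms are usually small.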