在 GitHub 上查看

矩阵机器学习策略

概述

矩阵机器学习策略源自 MetaTrader 5 平台上的 "MQL5Book" 教程脚本。原版程序收集一段最新报价,将相邻价格差分转换成二进制序列,并使用霍普菲尔德递归神经网络进行训练。训练好的网络会在样本内与样本外片段上进行验证,最终根据预测向量的第一个元素(+1 代表看涨,-1 代表看跌)决定交易方向。

C# 版本采用 StockSharp 高级 API,并使用收盘完成的 K 线代替逐笔数据,以获得更稳定的跨平台表现。每当 K 线收盘时,策略都会更新价格二进制模式、重新训练霍普菲尔德网络、评估历史准确度,并生成下一步的在线预测。

算法细节

  1. 收集最近 HistoryDepth 根 K 线的收盘价。最新的 ForwardDepth 个点作为样本外验证段,其余数据用于训练。
  2. 将连续收盘价的差值转换为二进制序列:非负差值映射为 +1,负差值映射为 -1。
  3. 根据 PredictorLength(输入长度)与 ForecastLength(输出长度)滑动窗口计算外积之和,构造霍普菲尔德权重矩阵。
  4. 在训练集与验证集上评估该矩阵。准确度指标与原脚本保持一致:对每个样本的预测向量与真实向量求点积并取平均值 m,再按 (m + ForecastLength) / (2 × ForecastLength) × 100% 换算成百分比。
  5. 构建最新的在线二进制模式,运行霍普菲尔德推理循环(tanh 激活与收敛阈值),预测向量的第一个元素用于驱动交易决策。

参数说明

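  • History Depth – 霍普菲尔德网络持有的历史收盘数量。训练段长度为 HistoryDepth − ForwardDepth,必须大于 PredictorLength + ForecastLength 才能完成训练,因此应满足 HistoryDepth ≥ ForwardDepth + PredictorLength + ForecastLength + 1。
  • Forward Depth – 保留用于样本外验证的收盘数量。只有当 ForwardDepth ≥ PredictorLength + ForecastLength 时,才会输出样本外评估结果。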
  • Predictor Length – 神经网络输入向量的长度。
  • Forecast Length – 神经网络输出向量预测的未来步数。
  • Candle Type – 指定由连接器订阅的 K 线数据类型。
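  • Debug Log – 启用后会在日志中输出详细的中间向量、样本比较以及在线预测结果。
  • Max Iterations – 每次推理时霍普菲尔德网络的最大迭代次数(默认 100)。
  • Accuracy – 判定神经元状态收敛的阈值:相邻两次迭代中各状态的最大变化量都小于该值时提前停止迭代(默认 0.00001)。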

交易逻辑

  • 当预测向量的首个元素为正且当前持仓≤0 时,策略以 Volume + |Position| 的数量市价买入,转为多头。
  • 当预测向量的首个元素为负且当前持仓≥0 时,策略以 Volume + |Position| 的数量市价卖出,转为空头。
  • 若预测结果为零,则忽略信号以避免不必要的交易频率。

若图表区域可用,策略会自动绘制 K 线与成交记录。霍普菲尔德网络在每根 K 线收盘时重新训练,使权重始终跟随最新的市场结构。

using System;
using System.Linq;
using System.Collections.Generic;

using Ecng.Common;
using Ecng.Collections;
using Ecng.Serialization;

using StockSharp.Algo.Indicators;
using StockSharp.Algo.Strategies;
using StockSharp.BusinessEntities;
using StockSharp.Messages;

using System.Text;

namespace StockSharp.Samples.Strategies;

/// <summary>
/// Strategy implementing a Hopfield neural network trained on price direction sequences.
/// </summary>
/// <remarks>
/// On every finished candle the close price is appended to a rolling window of at most
/// <see cref="HistoryDepth"/> values. The weight matrix is retrained on the in-sample part of the
/// window and evaluated on both the in-sample and the out-of-sample (<see cref="ForwardDepth"/>) parts.
/// The first element of the online forecast then decides the trade direction.
/// </remarks>
public class MatrixMachineLearningStrategy : Strategy
{
	private readonly StrategyParam<int> _maxIterations;
	private readonly StrategyParam<double> _accuracy;

	private readonly StrategyParam<int> _historyDepth;
	private readonly StrategyParam<int> _forwardDepth;
	private readonly StrategyParam<int> _predictorLength;
	private readonly StrategyParam<int> _forecastLength;
	private readonly StrategyParam<DataType> _candleType;
	private readonly StrategyParam<bool> _enableDebugLog;

	// Rolling window of finished-candle close prices, oldest first.
	private readonly List<decimal> _closes = new();

	// Weight matrix [PredictorLength, ForecastLength] from the most recent successful training; null until trained.
	private double[,] _weights;

	/// <summary>
	/// Number of most recent candle closes used for training.
	/// </summary>
	public int HistoryDepth
	{
		get => _historyDepth.Value;
		set => _historyDepth.Value = value;
	}

	/// <summary>
	/// Portion of the history reserved for forward evaluation.
	/// </summary>
	public int ForwardDepth
	{
		get => _forwardDepth.Value;
		set => _forwardDepth.Value = value;
	}

	/// <summary>
	/// Number of binary price movements forming the network input vector.
	/// </summary>
	public int PredictorLength
	{
		get => _predictorLength.Value;
		set => _predictorLength.Value = value;
	}

	/// <summary>
	/// Number of steps predicted by the network output vector.
	/// </summary>
	public int ForecastLength
	{
		get => _forecastLength.Value;
		set => _forecastLength.Value = value;
	}

	/// <summary>
	/// Candle type used to gather prices.
	/// </summary>
	public DataType CandleType
	{
		get => _candleType.Value;
		set => _candleType.Value = value;
	}

	/// <summary>
	/// Maximum number of Hopfield iterations executed per forecast.
	/// </summary>
	public int MaxIterations
	{
		get => _maxIterations.Value;
		set => _maxIterations.Value = value;
	}

	/// <summary>
	/// Desired accuracy when checking convergence of neuron states.
	/// </summary>
	public double Accuracy
	{
		get => _accuracy.Value;
		set => _accuracy.Value = value;
	}

	/// <summary>
	/// Enables verbose logging of the neural network state.
	/// </summary>
	public bool EnableDebugLog
	{
		get => _enableDebugLog.Value;
		set => _enableDebugLog.Value = value;
	}

	/// <summary>
	/// Initializes a new instance of the strategy.
	/// </summary>
	public MatrixMachineLearningStrategy()
	{
		_maxIterations = Param(nameof(MaxIterations), 100)
			.SetGreaterThanZero()
			.SetDisplay("Max Iterations", "Maximum number of Hopfield iterations executed per forecast.", "Machine Learning");

		_accuracy = Param(nameof(Accuracy), 0.00001)
			.SetDisplay("Accuracy", "Desired accuracy when checking convergence of neuron states.", "Machine Learning");

		_historyDepth = Param(nameof(HistoryDepth), 120)
			.SetGreaterThanZero()
			.SetDisplay("History Depth", "Total amount of closes stored for the Hopfield network.", "Machine Learning")
			.SetOptimize(80, 200, 10);

		_forwardDepth = Param(nameof(ForwardDepth), 60)
			.SetGreaterThanZero()
			.SetDisplay("Forward Depth", "Amount of closes kept for out-of-sample validation.", "Machine Learning")
			.SetOptimize(30, 120, 10);

		_predictorLength = Param(nameof(PredictorLength), 20)
			.SetGreaterThanZero()
			.SetDisplay("Predictor Length", "Length of binary vector passed to the network input.", "Machine Learning")
			.SetOptimize(10, 40, 2);

		_forecastLength = Param(nameof(ForecastLength), 10)
			.SetGreaterThanZero()
			.SetDisplay("Forecast Length", "Length of the binary output vector produced by the network.", "Machine Learning")
			.SetOptimize(5, 20, 1);

		_candleType = Param(nameof(CandleType), TimeSpan.FromMinutes(60).TimeFrame())
			.SetDisplay("Candle Type", "Type of candles requested from the market data source.", "Data");

		_enableDebugLog = Param(nameof(EnableDebugLog), false)
			.SetDisplay("Debug Log", "Write detailed neural network diagnostics to the log.", "Machine Learning");
	}

	/// <inheritdoc />
	public override IEnumerable<(Security sec, DataType dt)> GetWorkingSecurities()
		=> [(Security, CandleType)];

	/// <inheritdoc />
	protected override void OnReseted()
	{
		base.OnReseted();

		_closes.Clear();
		_weights = null;
	}

	/// <inheritdoc />
	protected override void OnStarted2(DateTime time)
	{
		base.OnStarted2(time);

		StartProtection(null, null);

		var subscription = SubscribeCandles(CandleType);
		subscription
			.Bind(ProcessCandle)
			.Start();

		var area = CreateChartArea();
		if (area != null)
		{
			DrawCandles(area, subscription);
			DrawOwnTrades(area);
		}
	}

	/// <summary>
	/// Handles a candle update: stores the close, retrains the network and trades on the online forecast.
	/// </summary>
	/// <param name="candle">Incoming candle; only finished candles are processed.</param>
	private void ProcessCandle(ICandleMessage candle)
	{
		if (candle.State != CandleStates.Finished)
			return;

		_closes.Add(candle.ClosePrice);

		// Trim every surplus element, not just one, so that lowering HistoryDepth at runtime takes effect immediately.
		var excess = _closes.Count - HistoryDepth;
		if (excess > 0)
			_closes.RemoveRange(0, excess);

		if (_closes.Count < PredictorLength + ForecastLength + 1)
			return;

		if (_closes.Count < ForwardDepth + 2)
			return;

		var closes = _closes.ToArray();

		TrainNetwork(closes);

		var forecast = Forecast(closes);
		if (forecast == null || forecast.Length == 0)
			return;

		// The first output column is trained on the move immediately following the input window,
		// i.e. the next-step direction; the documented trading rule uses only this element.
		var direction = forecast[0];
		if (direction > 0 && Position <= 0m)
		{
			// Volume + |Position| closes any short and opens a long of size Volume.
			BuyMarket(Volume + Math.Abs(Position));
		}
		else if (direction < 0 && Position >= 0m)
		{
			// Volume + |Position| closes any long and opens a short of size Volume.
			SellMarket(Volume + Math.Abs(Position));
		}
	}

	/// <summary>
	/// Splits the closes into a training part and a forward part, retrains <see cref="_weights"/>
	/// and logs the in-sample and out-of-sample accuracy.
	/// </summary>
	/// <param name="closes">Close prices, oldest first.</param>
	/// <remarks>If there is not enough training data, the previously trained weights are kept.</remarks>
	private void TrainNetwork(IReadOnlyList<decimal> closes)
	{
		var historyCount = closes.Count;
		var forwardCount = Math.Min(ForwardDepth, historyCount - 1);
		var trainingCount = historyCount - forwardCount;

		if (trainingCount <= PredictorLength + ForecastLength)
			return;

		var trainingData = BuildBinaryDiff(closes, 0, trainingCount);
		if (trainingData.Length < PredictorLength + ForecastLength)
			return;

		var weights = TrainWeights(trainingData, PredictorLength, ForecastLength);
		if (weights == null)
			return;

		_weights = weights;

		EvaluateWeights(trainingData, "Backtest evaluation");

		// Start one close earlier so the first forward diff spans the training/forward boundary.
		var forwardData = BuildBinaryDiff(closes, trainingCount - 1, forwardCount + 1);
		if (forwardData.Length >= PredictorLength + ForecastLength)
			EvaluateWeights(forwardData, "Forward evaluation");
	}

	/// <summary>
	/// Runs the trained network on the latest <see cref="PredictorLength"/> price directions.
	/// </summary>
	/// <param name="closes">Close prices, oldest first.</param>
	/// <returns>The ±1 forecast vector, or null when there are no weights or not enough closes.</returns>
	private double[] Forecast(IReadOnlyList<decimal> closes)
	{
		var weights = _weights;
		if (weights == null)
			return null;

		var pattern = BuildCurrentPattern(closes);
		if (pattern == null)
			return null;

		var forecast = RunWeights(weights, pattern);

		if (EnableDebugLog)
		{
			LogInfo(FormattableString.Invariant($"Online pattern: {FormatVector(pattern)}"));
			LogInfo(FormattableString.Invariant($"Forecast: {FormatVector(forecast)}"));
		}

		return forecast;
	}

	/// <summary>
	/// Converts consecutive close differences into a ±1 sequence (non-negative difference maps to +1).
	/// </summary>
	/// <param name="closes">Close prices, oldest first.</param>
	/// <param name="startIndex">Index of the first close to use.</param>
	/// <param name="length">Number of closes to use; it is clamped to the end of <paramref name="closes"/>.</param>
	/// <returns>An array of <c>length - 1</c> values, or an empty array if fewer than two closes are available.</returns>
	private static double[] BuildBinaryDiff(IReadOnlyList<decimal> closes, int startIndex, int length)
	{
		if (length <= 1 || startIndex < 0)
			return Array.Empty<double>();

		if (startIndex + length > closes.Count)
			length = closes.Count - startIndex;

		var resultLength = length - 1;
		if (resultLength <= 0)
			return Array.Empty<double>();

		var result = new double[resultLength];
		for (var i = 0; i < resultLength; i++)
		{
			var diff = closes[startIndex + i + 1] - closes[startIndex + i];
			result[i] = diff >= 0 ? 1d : -1d;
		}

		return result;
	}

	/// <summary>
	/// Builds the online input pattern from the last <see cref="PredictorLength"/> + 1 closes.
	/// </summary>
	/// <returns>A ±1 vector of length <see cref="PredictorLength"/>, or null if there are not enough closes.</returns>
	private double[] BuildCurrentPattern(IReadOnlyList<decimal> closes)
	{
		var required = PredictorLength + 1;
		if (closes.Count < required)
			return null;

		return BuildBinaryDiff(closes, closes.Count - required, required);
	}

	/// <summary>
	/// Builds the weight matrix as the sum of outer products of every sliding (input, output) window pair.
	/// </summary>
	/// <param name="data">±1 direction sequence.</param>
	/// <param name="predictor">Input window length (rows).</param>
	/// <param name="response">Output window length (columns).</param>
	/// <returns>A <c>[predictor, response]</c> matrix, or null if <paramref name="data"/> is shorter than one window pair.</returns>
	private static double[,] TrainWeights(double[] data, int predictor, int response)
	{
		var sample = predictor + response;
		if (data.Length < sample)
			return null;

		var count = data.Length - sample + 1;
		var weights = new double[predictor, response];

		for (var index = 0; index < count; index++)
		{
			for (var row = 0; row < predictor; row++)
			{
				var inputValue = data[index + row];
				for (var column = 0; column < response; column++)
				{
					var outputValue = data[index + predictor + column];
					weights[row, column] += inputValue * outputValue;
				}
			}
		}

		return weights;
	}

	/// <summary>
	/// Scores the current weights on every sliding window of <paramref name="data"/> and logs the result.
	/// </summary>
	/// <param name="data">±1 direction sequence to evaluate on.</param>
	/// <param name="title">Label written at the start of the summary log line.</param>
	/// <remarks>
	/// Each window's score is the dot product of forecast and target (range [-response, response]).
	/// The average score m is mapped to a percentage as (m + response) / (2 * response) * 100.
	/// </remarks>
	private void EvaluateWeights(double[] data, string title)
	{
		var weights = _weights;
		if (weights == null)
			return;

		var predictor = weights.GetLength(0);
		var response = weights.GetLength(1);
		var sample = predictor + response;

		if (data.Length < sample)
			return;

		var count = data.Length - sample + 1;
		if (count <= 0)
			return;

		var positive = 0;
		var negative = 0;
		double sum = 0;

		for (var index = 0; index < count; index++)
		{
			var input = new double[predictor];
			var target = new double[response];

			Array.Copy(data, index, input, 0, predictor);
			Array.Copy(data, index + predictor, target, 0, response);

			var forecast = RunWeights(weights, input);

			double match = 0;
			for (var i = 0; i < response; i++)
				match += forecast[i] * target[i];

			if (match > 0)
				positive++;
			else if (match < 0)
				negative++;

			sum += match;

			if (EnableDebugLog)
			{
				LogInfo(FormattableString.Invariant($"Sample {index}: forecast={FormatVector(forecast)} target={FormatVector(target)} match={match:0.###}"));
			}
		}

		var average = sum / count;
		var accuracy = (average + response) / (2.0 * response) * 100.0;

		LogInfo(FormattableString.Invariant($"{title}: count={count} positive={positive} negative={negative} accuracy={accuracy:0.##}%"));
	}

	/// <summary>
	/// Iterates the network between input state <c>a</c> and output state <c>b</c> with tanh activation.
	/// Iteration stops after <see cref="MaxIterations"/> steps, or earlier once the largest change in
	/// both states is below <see cref="Accuracy"/>.
	/// </summary>
	/// <param name="weights">A <c>[predictor, response]</c> weight matrix.</param>
	/// <param name="input">Initial input state; its length must equal the matrix row count.</param>
	/// <returns>
	/// The output state thresholded to ±1 (0 maps to +1). If the input length does not match,
	/// an all-zero vector is returned instead.
	/// </returns>
	private double[] RunWeights(double[,] weights, double[] input)
	{
		var predictor = weights.GetLength(0);
		var response = weights.GetLength(1);
		var forecast = new double[response];

		if (input.Length != predictor)
			return forecast;

		var a = new double[predictor];
		var b = new double[response];
		Array.Copy(input, a, predictor);

		// Snapshot buffers are reused across iterations instead of being reallocated each time.
		var previousA = new double[predictor];
		var previousB = new double[response];

		for (var iteration = 0; iteration < MaxIterations; iteration++)
		{
			Array.Copy(a, previousA, predictor);
			Array.Copy(b, previousB, response);

			for (var column = 0; column < response; column++)
			{
				double sum = 0;
				for (var row = 0; row < predictor; row++)
					sum += a[row] * weights[row, column];
				b[column] = Math.Tanh(sum);
			}

			for (var row = 0; row < predictor; row++)
			{
				double sum = 0;
				for (var column = 0; column < response; column++)
					sum += b[column] * weights[row, column];
				a[row] = Math.Tanh(sum);
			}

			var diffA = 0d;
			for (var i = 0; i < predictor; i++)
				diffA = Math.Max(diffA, Math.Abs(a[i] - previousA[i]));

			var diffB = 0d;
			for (var i = 0; i < response; i++)
				diffB = Math.Max(diffB, Math.Abs(b[i] - previousB[i]));

			if (diffA < Accuracy && diffB < Accuracy)
				break;
		}

		for (var i = 0; i < response; i++)
			forecast[i] = b[i] >= 0 ? 1d : -1d;

		return forecast;
	}

	/// <summary>
	/// Formats a vector as <c>[v1,v2,...]</c> using invariant culture and up to three decimals.
	/// </summary>
	private static string FormatVector(IReadOnlyList<double> values)
	{
		var builder = new StringBuilder();
		builder.Append('[');
		for (var i = 0; i < values.Count; i++)
		{
			builder.Append(FormattableString.Invariant($"{values[i]:0.###}"));
			if (i + 1 < values.Count)
				builder.Append(',');
		}
		builder.Append(']');
		return builder.ToString();
	}
}