Matrix Machine Learning Strategy
Overview
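The Matrix Machine Learning strategy is derived from the "MQL5Book" tutorial script for the MetaTrader 5 platform. The original program collects a window of recent quotes, converts the differences between adjacent prices into a binary sequence and trains a Hopfield recurrent neural network on it. The trained network is validated on in-sample and out-of-sample segments, and the trade direction is decided by the first element of the forecast vector (+1 means bullish, -1 means bearish).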
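The C# version uses the StockSharp high-level API and works with finished candles instead of tick data, which gives more stable behavior across platforms. Every time a candle closes, the strategy updates the binary price pattern, retrains the Hopfield network, evaluates its historical accuracy and produces the next online forecast.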
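The per-candle bookkeeping can be sketched in Python as follows. This is an illustrative sketch, not the strategy's code: `RollingCloses` and `on_finished_candle` are hypothetical names, and a `deque` stands in for the C# `_closes` list. The warm-up threshold and the training/forward split are modeled on `ProcessCandle` and `TrainNetwork`.

```python
from collections import deque


class RollingCloses:
    """Hypothetical helper mirroring the strategy's rolling close buffer."""

    def __init__(self, history_depth, forward_depth, predictor_length, forecast_length):
        # deque(maxlen=...) drops the oldest close automatically, like _closes.RemoveAt(0).
        self.closes = deque(maxlen=history_depth)
        self.forward_depth = forward_depth
        # Same warm-up gates as ProcessCandle: P + F + 1 closes and ForwardDepth + 2 closes.
        self.min_count = max(predictor_length + forecast_length + 1, forward_depth + 2)

    def on_finished_candle(self, close):
        """Store a finished close; return (training, forward) slices once ready, else None."""
        self.closes.append(close)
        if len(self.closes) < self.min_count:
            return None
        data = list(self.closes)
        forward_count = min(self.forward_depth, len(data) - 1)
        training_count = len(data) - forward_count
        # The forward slice starts one close earlier, so its first difference
        # spans the boundary between the two segments.
        return data[:training_count], data[training_count - 1:]


buffer = RollingCloses(history_depth=5, forward_depth=2, predictor_length=1, forecast_length=1)
for price in [1, 2, 3, 4]:
    result = buffer.on_finished_candle(price)
print(result)  # ([1, 2], [2, 3, 4])
```

Using a bounded `deque` keeps trimming O(1) per candle, whereas removing the first element of a list shifts every remaining item.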
Algorithm Details
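- Collect the closes of the most recent `HistoryDepth` candles. The latest `ForwardDepth` points form the out-of-sample validation segment; the remaining data is used for training.
- Convert the differences between consecutive closes into a binary sequence: non-negative differences map to +1, negative differences map to -1.
- Build the Hopfield weight matrix as the sum of outer products over sliding windows of `PredictorLength` inputs (input length) followed by `ForecastLength` outputs (output length).
- Evaluate the matrix on the training and validation sets. The accuracy metric matches the original script: the dot product of each forecast vector with the actual vector is averaged over all samples, and the average m, which lies between -ForecastLength and ForecastLength, is converted to a percentage as (m + ForecastLength) / (2 · ForecastLength) × 100.
- Build the latest online binary pattern from the last `PredictorLength` differences and run the Hopfield recall loop: the output and input layers are updated alternately with tanh activation until the largest state change in both layers falls below `Accuracy` or `MaxIterations` is reached. The output is then thresholded to +1/-1, and the first element of the forecast vector drives the trading decision.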
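A minimal Python sketch of the encoding and training steps, assuming closes arrive as a plain list of floats. `binary_diff` and `train_weights` are hypothetical helpers modeled on the C# `BuildBinaryDiff` and `TrainWeights` methods; the weight matrix is a nested list with `predictor` rows and `response` columns.

```python
def binary_diff(closes):
    """Map each consecutive close difference to +1.0 (non-negative) or -1.0 (negative)."""
    return [1.0 if b - a >= 0 else -1.0 for a, b in zip(closes, closes[1:])]


def train_weights(data, predictor, response):
    """Sum the outer products of every sliding (input, output) window.

    Returns a predictor x response nested list, or None when data is too short.
    """
    sample = predictor + response
    if len(data) < sample:
        return None
    weights = [[0.0] * response for _ in range(predictor)]
    for start in range(len(data) - sample + 1):
        inputs = data[start:start + predictor]
        outputs = data[start + predictor:start + sample]
        for row, x in enumerate(inputs):
            for col, y in enumerate(outputs):
                # Hebbian-style update: agreeing signs strengthen the link.
                weights[row][col] += x * y
    return weights


print(binary_diff([1, 2, 2, 1]))               # [1.0, 1.0, -1.0]
print(train_weights([1.0, -1.0, -1.0], 1, 2))  # [[-1.0, -1.0]]
```

Note that a zero price change is encoded as +1, exactly as in the strategy, so flat markets are read as mildly bullish.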
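The recall loop and the accuracy metric can be sketched the same way. `run_weights` and `accuracy_percent` are hypothetical Python counterparts of the C# `RunWeights` and `EvaluateWeights`; the weights are assumed to be a predictor x response nested list, and the defaults mirror the `MaxIterations` (100) and `Accuracy` (0.00001) parameters.

```python
import math


def run_weights(weights, pattern, max_iterations=100, accuracy=1e-5):
    """Alternate output/input updates with tanh until both layers stop changing.

    Returns the output layer thresholded to +1.0/-1.0, or zeros if the
    pattern length does not match the number of weight rows.
    """
    predictor, response = len(weights), len(weights[0])
    if len(pattern) != predictor:
        return [0.0] * response
    a, b = list(pattern), [0.0] * response
    for _ in range(max_iterations):
        prev_a, prev_b = a[:], b[:]
        # Output layer: b = tanh(W^T a); input layer: a = tanh(W b).
        b = [math.tanh(sum(a[r] * weights[r][c] for r in range(predictor))) for c in range(response)]
        a = [math.tanh(sum(b[c] * weights[r][c] for c in range(response))) for r in range(predictor)]
        diff_a = max(abs(x - y) for x, y in zip(a, prev_a))
        diff_b = max(abs(x - y) for x, y in zip(b, prev_b))
        if diff_a < accuracy and diff_b < accuracy:
            break
    return [1.0 if v >= 0 else -1.0 for v in b]


def accuracy_percent(forecasts, targets):
    """Map the mean forecast/target dot product from [-n, n] onto 0..100 %."""
    response = len(targets[0])
    dots = [sum(f * t for f, t in zip(fc, tg)) for fc, tg in zip(forecasts, targets)]
    average = sum(dots) / len(dots)
    return (average + response) / (2.0 * response) * 100.0


print(run_weights([[1.0], [1.0]], [1.0, 1.0]))                   # [1.0]
print(accuracy_percent([[1.0, -1.0]], [[1.0, 1.0]]))             # 50.0
```

A score of 50 % corresponds to an average dot product of zero, i.e. no better than chance.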
Parameters
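- History Depth – number of recent closes held for the Hopfield network. It must be greater than `ForwardDepth`, and the training part (`HistoryDepth - ForwardDepth`) must contain at least `PredictorLength + ForecastLength + 1` closes; otherwise training is skipped.
- Forward Depth – number of closes reserved for out-of-sample validation. The forward evaluation only runs when `ForwardDepth` is at least `PredictorLength + ForecastLength`.
- Predictor Length – length of the neural network input vector.
- Forecast Length – number of future steps predicted by the neural network output vector.
- Max Iterations – maximum number of Hopfield recall iterations per forecast.
- Accuracy – convergence threshold for the change in neuron states between iterations.
- Candle Type – type of candle data subscribed through the connector.
- Debug Log – when enabled, writes detailed intermediate vectors, sample comparisons and online forecasts to the log.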
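The depth constraints can be checked up front with a small helper. `check_parameters` is hypothetical; its conditions are read off the C# `TrainNetwork` method, which skips training when the training part holds `PredictorLength + ForecastLength` closes or fewer, and skips forward evaluation when `ForwardDepth` is below `PredictorLength + ForecastLength`.

```python
def check_parameters(history_depth, forward_depth, predictor_length, forecast_length):
    """Return warnings for depth settings the strategy cannot fully use (hypothetical helper)."""
    problems = []
    sample = predictor_length + forecast_length
    if history_depth <= forward_depth:
        problems.append("HistoryDepth must be greater than ForwardDepth")
    if history_depth - forward_depth < sample + 1:
        problems.append("training part too short: training is skipped")
    if forward_depth < sample:
        problems.append("ForwardDepth too short: forward evaluation is skipped")
    return problems


print(check_parameters(120, 60, 20, 10))  # [] for the default values
print(check_parameters(80, 60, 20, 10))   # ['training part too short: training is skipped']
```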
Trading Logic
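- When the first element of the forecast vector is positive and the current position is ≤ 0, the strategy buys at market for `Volume + |Position|`, opening or reversing into a long position.
- When the first element of the forecast vector is negative and the current position is ≥ 0, the strategy sells at market for `Volume + |Position|`, opening or reversing into a short position.
- If the forecast is zero, the signal is ignored to avoid unnecessary trades.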
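The order sizing above amounts to the following rule, sketched in Python with a hypothetical `reversal_order` function that takes the signed position and the base volume as plain numbers rather than StockSharp objects.

```python
def reversal_order(direction, position, volume):
    """Return ("buy" | "sell", quantity) or None for a forecast direction.

    direction: first element of the forecast vector (+1, -1 or 0).
    position: current signed position; volume: the strategy's base Volume.
    """
    if direction > 0 and position <= 0:
        # Covers the short (if any) and leaves a long of size `volume`.
        return ("buy", volume + abs(position))
    if direction < 0 and position >= 0:
        # Closes the long (if any) and leaves a short of size `volume`.
        return ("sell", volume + abs(position))
    return None


print(reversal_order(1, -2, 1))  # ('buy', 3): buy 2 to flatten, plus 1 to go long
```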
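If a chart area is available, the strategy automatically draws the candles and its own trades. The Hopfield network is retrained on every finished candle, so the weights always follow the latest market structure.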
```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

using Ecng.Common;
using Ecng.Collections;
using Ecng.Serialization;

using StockSharp.Algo.Indicators;
using StockSharp.Algo.Strategies;
using StockSharp.BusinessEntities;
using StockSharp.Messages;

namespace StockSharp.Samples.Strategies;

/// <summary>
/// Strategy implementing a Hopfield neural network trained on price direction sequences.
/// </summary>
public class MatrixMachineLearningStrategy : Strategy
{
    private readonly StrategyParam<int> _maxIterations;
    private readonly StrategyParam<double> _accuracy;
    private readonly StrategyParam<int> _historyDepth;
    private readonly StrategyParam<int> _forwardDepth;
    private readonly StrategyParam<int> _predictorLength;
    private readonly StrategyParam<int> _forecastLength;
    private readonly StrategyParam<DataType> _candleType;
    private readonly StrategyParam<bool> _enableDebugLog;

    // Rolling window of finished candle closes (at most HistoryDepth items).
    private readonly List<decimal> _closes = new();

    // Predictor x response weight matrix from the last successful training run.
    private double[,] _weights;

    /// <summary>
    /// Number of most recent candle closes kept for training and forward evaluation.
    /// </summary>
    public int HistoryDepth
    {
        get => _historyDepth.Value;
        set => _historyDepth.Value = value;
    }

    /// <summary>
    /// Portion of the history reserved for forward evaluation.
    /// </summary>
    public int ForwardDepth
    {
        get => _forwardDepth.Value;
        set => _forwardDepth.Value = value;
    }

    /// <summary>
    /// Number of binary price movements forming the network input vector.
    /// </summary>
    public int PredictorLength
    {
        get => _predictorLength.Value;
        set => _predictorLength.Value = value;
    }

    /// <summary>
    /// Number of steps predicted by the network output vector.
    /// </summary>
    public int ForecastLength
    {
        get => _forecastLength.Value;
        set => _forecastLength.Value = value;
    }

    /// <summary>
    /// Candle type used to gather prices.
    /// </summary>
    public DataType CandleType
    {
        get => _candleType.Value;
        set => _candleType.Value = value;
    }

    /// <summary>
    /// Maximum number of Hopfield iterations executed per forecast.
    /// </summary>
    public int MaxIterations
    {
        get => _maxIterations.Value;
        set => _maxIterations.Value = value;
    }

    /// <summary>
    /// Desired accuracy when checking convergence of neuron states.
    /// </summary>
    public double Accuracy
    {
        get => _accuracy.Value;
        set => _accuracy.Value = value;
    }

    /// <summary>
    /// Enables verbose logging of the neural network state.
    /// </summary>
    public bool EnableDebugLog
    {
        get => _enableDebugLog.Value;
        set => _enableDebugLog.Value = value;
    }

    /// <summary>
    /// Initializes a new instance of the strategy.
    /// </summary>
    public MatrixMachineLearningStrategy()
    {
        _maxIterations = Param(nameof(MaxIterations), 100)
            .SetGreaterThanZero()
            .SetDisplay("Max Iterations", "Maximum number of Hopfield iterations executed per forecast.", "Machine Learning");

        _accuracy = Param(nameof(Accuracy), 0.00001)
            .SetDisplay("Accuracy", "Desired accuracy when checking convergence of neuron states.", "Machine Learning");

        _historyDepth = Param(nameof(HistoryDepth), 120)
            .SetGreaterThanZero()
            .SetDisplay("History Depth", "Total amount of closes stored for the Hopfield network.", "Machine Learning")
            .SetOptimize(80, 200, 10);

        _forwardDepth = Param(nameof(ForwardDepth), 60)
            .SetGreaterThanZero()
            .SetDisplay("Forward Depth", "Amount of closes kept for out-of-sample validation.", "Machine Learning")
            .SetOptimize(30, 120, 10);

        _predictorLength = Param(nameof(PredictorLength), 20)
            .SetGreaterThanZero()
            .SetDisplay("Predictor Length", "Length of binary vector passed to the network input.", "Machine Learning")
            .SetOptimize(10, 40, 2);

        _forecastLength = Param(nameof(ForecastLength), 10)
            .SetGreaterThanZero()
            .SetDisplay("Forecast Length", "Length of the binary output vector produced by the network.", "Machine Learning")
            .SetOptimize(5, 20, 1);

        _candleType = Param(nameof(CandleType), TimeSpan.FromMinutes(60).TimeFrame())
            .SetDisplay("Candle Type", "Type of candles requested from the market data source.", "Data");

        _enableDebugLog = Param(nameof(EnableDebugLog), false)
            .SetDisplay("Debug Log", "Write detailed neural network diagnostics to the log.", "Machine Learning");
    }

    /// <inheritdoc />
    public override IEnumerable<(Security sec, DataType dt)> GetWorkingSecurities()
        => [(Security, CandleType)];

    /// <inheritdoc />
    protected override void OnReseted()
    {
        base.OnReseted();

        _closes.Clear();
        _weights = null;
    }

    /// <inheritdoc />
    protected override void OnStarted2(DateTime time)
    {
        base.OnStarted2(time);

        StartProtection(null, null);

        var subscription = SubscribeCandles(CandleType);
        subscription
            .Bind(ProcessCandle)
            .Start();

        var area = CreateChartArea();
        if (area != null)
        {
            DrawCandles(area, subscription);
            DrawOwnTrades(area);
        }
    }

    private void ProcessCandle(ICandleMessage candle)
    {
        // Work only with completed candles.
        if (candle.State != CandleStates.Finished)
            return;

        // Maintain a rolling window of at most HistoryDepth closes.
        _closes.Add(candle.ClosePrice);
        if (_closes.Count > HistoryDepth)
            _closes.RemoveAt(0);

        // Wait until enough closes are available for training and validation.
        if (_closes.Count < PredictorLength + ForecastLength + 1)
            return;

        if (_closes.Count < ForwardDepth + 2)
            return;

        var closes = _closes.ToArray();

        // Retrain on every finished candle so the weights follow the latest data.
        TrainNetwork(closes);

        var forecast = Forecast(closes);
        if (forecast == null || forecast.Length == 0)
            return;

        // The first element of the forecast vector (+1 bullish, -1 bearish) drives the trade.
        var direction = forecast[0];

        if (direction > 0 && Position <= 0m)
        {
            // Cover any short position and open a long position of size Volume.
            BuyMarket(Volume + Math.Abs(Position));
        }
        else if (direction < 0 && Position >= 0m)
        {
            // Close any long position and open a short position of size Volume.
            SellMarket(Volume + Math.Abs(Position));
        }
    }

    private void TrainNetwork(IReadOnlyList<decimal> closes)
    {
        var historyCount = closes.Count;
        var forwardCount = Math.Min(ForwardDepth, historyCount - 1);
        var trainingCount = historyCount - forwardCount;

        // The training segment needs at least PredictorLength + ForecastLength + 1 closes.
        if (trainingCount <= PredictorLength + ForecastLength)
            return;

        var trainingData = BuildBinaryDiff(closes, 0, trainingCount);
        if (trainingData.Length < PredictorLength + ForecastLength)
            return;

        var weights = TrainWeights(trainingData, PredictorLength, ForecastLength);
        if (weights == null)
            return;

        _weights = weights;

        EvaluateWeights(trainingData, "Backtest evaluation");

        // The forward segment starts one close earlier so its first difference spans the segment boundary.
        var forwardData = BuildBinaryDiff(closes, trainingCount - 1, forwardCount + 1);
        if (forwardData.Length >= PredictorLength + ForecastLength)
            EvaluateWeights(forwardData, "Forward evaluation");
    }

    private double[] Forecast(IReadOnlyList<decimal> closes)
    {
        var weights = _weights;
        if (weights == null)
            return null;

        var pattern = BuildCurrentPattern(closes);
        if (pattern == null)
            return null;

        var forecast = RunWeights(weights, pattern);

        if (EnableDebugLog)
        {
            LogInfo(FormattableString.Invariant($"Online pattern: {FormatVector(pattern)}"));
            LogInfo(FormattableString.Invariant($"Forecast: {FormatVector(forecast)}"));
        }

        return forecast;
    }

    // Encodes consecutive close differences as +1 (non-negative) or -1 (negative).
    private static double[] BuildBinaryDiff(IReadOnlyList<decimal> closes, int startIndex, int length)
    {
        if (length <= 1 || startIndex < 0)
            return Array.Empty<double>();

        if (startIndex + length > closes.Count)
            length = closes.Count - startIndex;

        var resultLength = length - 1;
        if (resultLength <= 0)
            return Array.Empty<double>();

        var result = new double[resultLength];

        for (var i = 0; i < resultLength; i++)
        {
            var first = closes[startIndex + i];
            var second = closes[startIndex + i + 1];
            var diff = (double)(second - first);
            result[i] = diff >= 0 ? 1d : -1d;
        }

        return result;
    }

    // Builds the online input pattern from the last PredictorLength differences.
    private double[] BuildCurrentPattern(IReadOnlyList<decimal> closes)
    {
        var required = PredictorLength + 1;
        if (closes.Count < required)
            return null;

        var startIndex = closes.Count - required;
        var pattern = new double[PredictorLength];

        for (var i = 0; i < PredictorLength; i++)
        {
            var first = closes[startIndex + i];
            var second = closes[startIndex + i + 1];
            var diff = (double)(second - first);
            pattern[i] = diff >= 0 ? 1d : -1d;
        }

        return pattern;
    }

    // Sums the outer products of every sliding (input, output) window.
    private static double[,] TrainWeights(double[] data, int predictor, int response)
    {
        var sample = predictor + response;
        if (data.Length < sample)
            return null;

        var count = data.Length - sample + 1;
        var weights = new double[predictor, response];

        for (var index = 0; index < count; index++)
        {
            for (var row = 0; row < predictor; row++)
            {
                var inputValue = data[index + row];

                for (var column = 0; column < response; column++)
                {
                    var outputValue = data[index + predictor + column];
                    weights[row, column] += inputValue * outputValue;
                }
            }
        }

        return weights;
    }

    private void EvaluateWeights(double[] data, string title)
    {
        var weights = _weights;
        if (weights == null)
            return;

        var predictor = weights.GetLength(0);
        var response = weights.GetLength(1);
        var sample = predictor + response;

        if (data.Length < sample)
            return;

        var count = data.Length - sample + 1;
        if (count <= 0)
            return;

        var positive = 0;
        var negative = 0;
        double sum = 0;

        for (var index = 0; index < count; index++)
        {
            var input = new double[predictor];
            var target = new double[response];

            for (var i = 0; i < predictor; i++)
                input[i] = data[index + i];

            for (var i = 0; i < response; i++)
                target[i] = data[index + predictor + i];

            var forecast = RunWeights(weights, input);

            // Dot product of forecast and target lies in [-response, response].
            double match = 0;
            for (var i = 0; i < response; i++)
                match += forecast[i] * target[i];

            if (match > 0)
                positive++;
            else if (match < 0)
                negative++;

            sum += match;

            if (EnableDebugLog)
            {
                LogInfo(FormattableString.Invariant($"Sample {index}: forecast={FormatVector(forecast)} target={FormatVector(target)} match={match:0.###}"));
            }
        }

        // Map the average dot product from [-response, response] onto 0..100 %.
        var average = sum / count;
        var accuracy = (average + response) / (2.0 * response) * 100.0;

        LogInfo(FormattableString.Invariant($"{title}: count={count} positive={positive} negative={negative} accuracy={accuracy:0.##}%"));
    }

    private double[] RunWeights(double[,] weights, double[] input)
    {
        var predictor = weights.GetLength(0);
        var response = weights.GetLength(1);
        var forecast = new double[response];

        if (input.Length != predictor)
            return forecast;

        var a = new double[predictor];
        var b = new double[response];

        for (var i = 0; i < predictor; i++)
            a[i] = input[i];

        for (var iteration = 0; iteration < MaxIterations; iteration++)
        {
            var previousA = new double[predictor];
            var previousB = new double[response];

            for (var i = 0; i < predictor; i++)
                previousA[i] = a[i];

            for (var i = 0; i < response; i++)
                previousB[i] = b[i];

            // Output layer update: b = tanh(W^T a).
            for (var column = 0; column < response; column++)
            {
                double sum = 0;
                for (var row = 0; row < predictor; row++)
                    sum += a[row] * weights[row, column];

                b[column] = Math.Tanh(sum);
            }

            // Input layer update: a = tanh(W b).
            for (var row = 0; row < predictor; row++)
            {
                double sum = 0;
                for (var column = 0; column < response; column++)
                    sum += b[column] * weights[row, column];

                a[row] = Math.Tanh(sum);
            }

            // Stop when the largest state change in both layers is below Accuracy.
            var diffA = 0d;
            for (var i = 0; i < predictor; i++)
            {
                var delta = Math.Abs(a[i] - previousA[i]);
                if (delta > diffA)
                    diffA = delta;
            }

            var diffB = 0d;
            for (var i = 0; i < response; i++)
            {
                var delta = Math.Abs(b[i] - previousB[i]);
                if (delta > diffB)
                    diffB = delta;
            }

            if (diffA < Accuracy && diffB < Accuracy)
                break;
        }

        // Threshold the output layer to +1/-1.
        for (var i = 0; i < response; i++)
            forecast[i] = b[i] >= 0 ? 1d : -1d;

        return forecast;
    }

    private static string FormatVector(IReadOnlyList<double> values)
    {
        var builder = new StringBuilder();
        builder.Append('[');

        for (var i = 0; i < values.Count; i++)
        {
            builder.Append(FormattableString.Invariant($"{values[i]:0.###}"));

            if (i + 1 < values.Count)
                builder.Append(',');
        }

        builder.Append(']');
        return builder.ToString();
    }
}
```
```python
import clr

clr.AddReference("StockSharp.Messages")
clr.AddReference("StockSharp.Algo")
clr.AddReference("StockSharp.Algo.Indicators")
clr.AddReference("StockSharp.Algo.Strategies")

from System import TimeSpan
from StockSharp.Messages import DataType, CandleStates
from StockSharp.Algo.Strategies import Strategy


class matrix_machine_learning_strategy(Strategy):
    """
    Matrix Machine Learning: Hopfield neural network on price direction sequences.
    Simplified Python version using momentum-based prediction.
    """

    def __init__(self):
        super(matrix_machine_learning_strategy, self).__init__()
        self._history_depth = self.Param("HistoryDepth", 120).SetDisplay("History", "Closes stored for network", "ML")
        self._predictor_length = self.Param("PredictorLength", 20).SetDisplay("Predictor", "Input vector length", "ML")
        self._forecast_length = self.Param("ForecastLength", 10).SetDisplay("Forecast", "Output vector length", "ML")
        self._cooldown_bars = self.Param("CooldownBars", 5).SetDisplay("Cooldown", "Bars between signals", "Risk")
        self._candle_type = self.Param("CandleType", DataType.TimeFrame(TimeSpan.FromMinutes(60))).SetDisplay("Candle Type", "Candles", "General")
        self._closes = []
        self._cooldown = 0

    @property
    def candle_type(self):
        return self._candle_type.Value

    def OnReseted(self):
        super(matrix_machine_learning_strategy, self).OnReseted()
        self._closes = []
        self._cooldown = 0

    def OnStarted2(self, time):
        super(matrix_machine_learning_strategy, self).OnStarted2(time)
        subscription = self.SubscribeCandles(self.candle_type)
        subscription.Bind(self._process_candle).Start()
        area = self.CreateChartArea()
        if area is not None:
            self.DrawCandles(area, subscription)
            self.DrawOwnTrades(area)

    def _process_candle(self, candle):
        # Work only with completed candles.
        if candle.State != CandleStates.Finished:
            return

        # Keep a rolling window of at most HistoryDepth closes.
        close = float(candle.ClosePrice)
        self._closes.append(close)
        hd = self._history_depth.Value
        if len(self._closes) > hd:
            self._closes = self._closes[-hd:]

        pl = self._predictor_length.Value
        fl = self._forecast_length.Value
        if len(self._closes) < pl + fl + 1:
            return

        # Skip signals while the cooldown after the last trade is active.
        if self._cooldown > 0:
            self._cooldown -= 1
            return

        # Encode consecutive close differences as +1 / -1.
        diffs = []
        for i in range(len(self._closes) - 1):
            diffs.append(1.0 if self._closes[i + 1] >= self._closes[i] else -1.0)
        if len(diffs) < pl:
            return

        # Momentum proxy: the sign of the sum of the last PredictorLength moves.
        recent = diffs[-pl:]
        direction = sum(recent)
        if direction > 0 and self.Position <= 0:
            self.BuyMarket()
            self._cooldown = self._cooldown_bars.Value
        elif direction < 0 and self.Position >= 0:
            self.SellMarket()
            self._cooldown = self._cooldown_bars.Value

    def CreateClone(self):
        return matrix_machine_learning_strategy()
```