# 神经网络模板策略

## 概述

该策略来源于 MQL5 神经网络模板专家顾问。由于 StockSharp 环境中没有原始项目使用的自定义网络加载器,这里用可解释的得分模型替代黑盒网络,同时保留原策略的结构、指标和风险控制。目标是在 RSI 与 MACD 的综合得分给出明确方向、且预测波动幅度足够大时捕捉动能行情。
## 指标与数据

- **相对强弱指标 RSI(12 周期)**:基于收盘价计算(原模板使用典型价格作为输入)。
- **MACD 指标(12/48/12)**:以 MACD 线与信号线之差(直方图)衡量动能与置信度。
- **时间框架**:可配置,默认使用 5 分钟 K 线,与原始专家设置保持一致。
## 交易逻辑

- 每根收盘的 K 线都会更新 RSI 与 MACD 直方图的滚动队列,队列长度由 `BarsToPattern` 控制。
- 当前 RSI 相对 50 水平的偏离(归一化到 [-1, 1])与 MACD 直方图相对其滚动均值的偏差(经双曲正切函数压缩到 (-1, 1))按 0.6 / 0.4 的权重组合成得分。得分的绝对值即置信度,用以模拟原神经网络的输出。
- 当置信度不低于 `TradeLevel`,且预测幅度(换算为价格步长点数)在得分方向上不小于 `MinTargetPoints` 时,策略按照得分指示的方向下市价单。
- 预测幅度乘以 `ProfitMultiply` 得到动态止盈距离,该距离被限制在 `MaxTakeProfitPoints` 之内。止盈价与按 `StopLossPoints` 计算的固定止损价一起保存在内部变量中,用于手动管理退出。
- 持仓期间,每根收盘 K 线都会用最高价/最低价检查是否触及止盈或止损,一旦触及即以市价平仓并重置内部状态。
## 参数

| 参数 | 说明 |
|---|---|
| `BarsToPattern` | 用于计算 RSI 与 MACD 统计的滚动窗口长度。 |
| `TradeLevel` | 开仓前所需的最小置信度(0-1)。 |
| `ProfitMultiply` | 应用于预测幅度的倍数,结果再受 `MaxTakeProfitPoints` 限制。 |
| `MinTargetPoints` | 进入交易所需的最小预测点数。 |
| `MaxTakeProfitPoints` | 止盈的最大允许点数。 |
| `StopLossPoints` | 相对信号 K 线收盘价(近似入场价)的固定止损距离(点数)。 |
| `TradeVolume` | 每次市价单的下单手数。 |
| `CandleType` | 策略订阅的 K 线类型或时间框架。 |
## 说明

- 采用确定性置信度模型,便于调试与复现,同时保留原神经网络策略的整体结构。
- 为复刻原模板的动态止盈/止损逻辑,策略使用内部变量管理每笔交易的目标价位与止损价位。
- 策略只在没有持仓时才评估新的入场信号,与原 MQL5 专家顾问的单仓模式一致。
using System;
using System.Linq;
using System.Collections.Generic;
using Ecng.Common;
using Ecng.Collections;
using Ecng.Serialization;
using StockSharp.Algo.Indicators;
using StockSharp.Algo.Strategies;
using StockSharp.BusinessEntities;
using StockSharp.Messages;
namespace StockSharp.Samples.Strategies;
/// <summary>
/// Momentum strategy inspired by a neural-network driven template from MQL5.
/// </summary>
/// <remarks>
/// The original network is replaced by a deterministic score. The score is the current RSI's
/// distance from 50 (weight 0.6) plus the tanh-squashed deviation of the MACD histogram
/// from its rolling mean (weight 0.4).
/// Only one position is held at a time. Exits are handled by internally tracked take-profit
/// and stop-loss prices that are checked against the high/low of every finished candle.
/// </remarks>
public class NeuralNetworkTemplateStrategy : Strategy
{
	// Indicator settings fixed by the original template.
	private const int RsiLength = 12;
	private const int MacdShortLength = 12;
	private const int MacdLongLength = 48;
	private const int MacdSignalLength = 12;

	// Score weights; they sum to 1 so the absolute score stays within [0, 1].
	private const decimal RsiWeight = 0.6m;
	private const decimal MomentumWeight = 0.4m;

	// Gain applied to the MACD deviation before tanh squashing.
	private const decimal MomentumScale = 5m;

	private readonly StrategyParam<int> _barsToPattern;
	private readonly StrategyParam<int> _maxTakeProfitPoints;
	private readonly StrategyParam<int> _minTargetPoints;
	private readonly StrategyParam<int> _stopLossPoints;
	private readonly StrategyParam<decimal> _profitMultiply;
	private readonly StrategyParam<decimal> _tradeLevel;
	private readonly StrategyParam<decimal> _volume;
	private readonly StrategyParam<DataType> _candleType;

	private RelativeStrengthIndex _rsi = null!;
	private MovingAverageConvergenceDivergenceSignal _macd = null!;

	// Rolling windows holding at most BarsToPattern RSI values / MACD histogram values.
	private readonly Queue<decimal> _rsiHistory = new();
	private readonly Queue<decimal> _macdHistory = new();

	// Running sum of _macdHistory, so the rolling mean costs O(1) per candle.
	private decimal _macdSum;

	// Exit levels of the tracked trade; null when nothing is tracked.
	private decimal? _targetPrice;
	private decimal? _stopPrice;

	// +1 after a long entry, -1 after a short entry, 0 when nothing is tracked.
	private int _positionDirection;

	/// <summary>
	/// Number of candles used for pattern recognition.
	/// </summary>
	public int BarsToPattern
	{
		get => _barsToPattern.Value;
		set => _barsToPattern.Value = value;
	}

	/// <summary>
	/// Upper bound for calculated take-profit in points.
	/// </summary>
	public int MaxTakeProfitPoints
	{
		get => _maxTakeProfitPoints.Value;
		set => _maxTakeProfitPoints.Value = value;
	}

	/// <summary>
	/// Minimum projected move required to open a trade.
	/// </summary>
	public int MinTargetPoints
	{
		get => _minTargetPoints.Value;
		set => _minTargetPoints.Value = value;
	}

	/// <summary>
	/// Stop-loss distance in points.
	/// </summary>
	public int StopLossPoints
	{
		get => _stopLossPoints.Value;
		set => _stopLossPoints.Value = value;
	}

	/// <summary>
	/// Multiplier applied to the projected move returned by the scoring model.
	/// </summary>
	public decimal ProfitMultiply
	{
		get => _profitMultiply.Value;
		set => _profitMultiply.Value = value;
	}

	/// <summary>
	/// Required confidence level before opening a new position.
	/// </summary>
	public decimal TradeLevel
	{
		get => _tradeLevel.Value;
		set => _tradeLevel.Value = value;
	}

	/// <summary>
	/// Trading volume for every market order.
	/// </summary>
	public decimal TradeVolume
	{
		get => _volume.Value;
		set => _volume.Value = value;
	}

	/// <summary>
	/// Candle type used by the strategy.
	/// </summary>
	public DataType CandleType
	{
		get => _candleType.Value;
		set => _candleType.Value = value;
	}

	/// <summary>
	/// Initializes a new instance of <see cref="NeuralNetworkTemplateStrategy"/>.
	/// </summary>
	public NeuralNetworkTemplateStrategy()
	{
		_barsToPattern = Param(nameof(BarsToPattern), 3)
			.SetGreaterThanZero()
			.SetDisplay("Bars", "Candles analysed", "Model");

		_maxTakeProfitPoints = Param(nameof(MaxTakeProfitPoints), 500)
			.SetGreaterThanZero()
			.SetDisplay("Max TP", "Maximum take-profit in points", "Risk");

		_minTargetPoints = Param(nameof(MinTargetPoints), 1)
			.SetGreaterThanZero()
			.SetDisplay("Min Target", "Minimum projected move in points", "Model");

		_stopLossPoints = Param(nameof(StopLossPoints), 300)
			.SetGreaterThanZero()
			.SetDisplay("Stop-Loss", "Stop-loss distance in points", "Risk");

		_profitMultiply = Param(nameof(ProfitMultiply), 0.8m)
			.SetNotNegative()
			.SetDisplay("Profit Mult", "Take-profit multiplier", "Model");

		_tradeLevel = Param(nameof(TradeLevel), 0.1m)
			.SetRange(0m, 1m)
			.SetDisplay("Trade Level", "Required confidence", "Model");

		_volume = Param(nameof(TradeVolume), 0.1m)
			.SetGreaterThanZero()
			.SetDisplay("Volume", "Order volume", "Trading");

		_candleType = Param(nameof(CandleType), TimeSpan.FromMinutes(5).TimeFrame())
			.SetDisplay("TF", "Working timeframe", "General");
	}

	/// <inheritdoc />
	public override IEnumerable<(Security sec, DataType dt)> GetWorkingSecurities()
	{
		return [(Security, CandleType)];
	}

	/// <inheritdoc />
	protected override void OnReseted()
	{
		base.OnReseted();

		_rsi = null!;
		_macd = null!;
		_rsiHistory.Clear();
		_macdHistory.Clear();
		_macdSum = 0m;
		ResetTargets();
	}

	/// <inheritdoc />
	protected override void OnStarted2(DateTime time)
	{
		base.OnStarted2(time);

		_rsi = new RelativeStrengthIndex { Length = RsiLength };
		_macd = new MovingAverageConvergenceDivergenceSignal
		{
			Macd =
			{
				ShortMa = { Length = MacdShortLength },
				LongMa = { Length = MacdLongLength }
			},
			SignalMa = { Length = MacdSignalLength }
		};

		var subscription = SubscribeCandles(CandleType);
		subscription
			.BindEx(_rsi, _macd, ProcessCandle)
			.Start();

		var area = CreateChartArea();
		if (area != null)
		{
			DrawCandles(area, subscription);
			DrawIndicator(area, _rsi);
			DrawIndicator(area, _macd);
			DrawOwnTrades(area);
		}
	}

	/// <summary>
	/// Per-candle handler. It manages exits, feeds the rolling windows, and evaluates a new
	/// entry when flat.
	/// </summary>
	private void ProcessCandle(ICandleMessage candle, IIndicatorValue rsiValue, IIndicatorValue macdValue)
	{
		if (candle.State != CandleStates.Finished)
			return;

		// Exits are checked before indicator warm-up so an open trade is always managed.
		ManageOpenPosition(candle);

		if (!_rsi.IsFormed || !_macd.IsFormed)
			return;

		var rsiDecimal = rsiValue.ToDecimal();

		if (macdValue is not MovingAverageConvergenceDivergenceSignalValue macdComponents)
			return;

		if (macdComponents.Macd is not decimal macdLine ||
			macdComponents.Signal is not decimal signalLine)
			return;

		UpdateHistory(rsiDecimal, macdLine - signalLine);

		// Single-position mode: only look for entries while flat.
		if (Position != 0)
			return;

		EvaluateEntry(candle, rsiDecimal, macdLine, signalLine);
	}

	/// <summary>
	/// Appends the latest values to the rolling windows and trims them to <see cref="BarsToPattern"/>.
	/// </summary>
	private void UpdateHistory(decimal rsiValue, decimal macdHistogram)
	{
		// A while-loop (rather than a single dequeue) keeps the windows correct
		// even if BarsToPattern is lowered while the strategy is running.
		_rsiHistory.Enqueue(rsiValue);
		while (_rsiHistory.Count > BarsToPattern)
			_rsiHistory.Dequeue();

		_macdHistory.Enqueue(macdHistogram);
		_macdSum += macdHistogram;
		while (_macdHistory.Count > BarsToPattern)
			_macdSum -= _macdHistory.Dequeue();
	}

	/// <summary>
	/// Scores the current candle and, if the score is confident enough and the projected move
	/// is large enough, opens a market position. It also records the exit levels for that position.
	/// </summary>
	private void EvaluateEntry(ICandleMessage candle, decimal rsiValue, decimal macdLine, decimal signalLine)
	{
		// Wait until both rolling windows are full. This also guarantees a non-empty MACD window below.
		if (_rsiHistory.Count < BarsToPattern || _macdHistory.Count < BarsToPattern)
			return;

		var priceStep = Security?.PriceStep ?? 1m;
		if (priceStep <= 0m)
			priceStep = 1m;

		// RSI component: distance from the neutral 50 level mapped into [-1, 1].
		var normalizedRsi = Math.Clamp((rsiValue - 50m) / 50m, -1m, 1m);

		// Momentum component: histogram deviation from its rolling mean (the window already
		// contains the current value), squashed into (-1, 1) by tanh.
		var macdHistogram = macdLine - signalLine;
		var macdAverage = _macdSum / _macdHistory.Count;
		var macdDeviation = macdHistogram - macdAverage;
		var normalizedMomentum = (decimal)Math.Tanh((double)(macdDeviation * MomentumScale));

		var combinedScore = normalizedRsi * RsiWeight + normalizedMomentum * MomentumWeight;
		var confidence = Math.Min(1m, Math.Abs(combinedScore));

		// Projected move in price units, and the same move expressed in price steps ("points").
		var projectedMove = macdDeviation * BarsToPattern;
		var projectedPoints = projectedMove / priceStep;

		if (combinedScore > 0m)
		{
			if (confidence < TradeLevel)
				return;

			if (projectedPoints < MinTargetPoints)
				return;

			var takeProfit = candle.ClosePrice + Math.Min(projectedMove * ProfitMultiply, MaxTakeProfitPoints * priceStep);
			var stopLoss = candle.ClosePrice - StopLossPoints * priceStep;

			// A zero multiplier yields no profit distance; skip such trades.
			if (takeProfit <= candle.ClosePrice)
				return;

			BuyMarket(TradeVolume);
			_targetPrice = takeProfit;
			_stopPrice = stopLoss;
			_positionDirection = 1;
		}
		else if (combinedScore < 0m)
		{
			if (confidence < TradeLevel)
				return;

			if (projectedPoints > -MinTargetPoints)
				return;

			var takeProfit = candle.ClosePrice + Math.Max(projectedMove * ProfitMultiply, -MaxTakeProfitPoints * priceStep);
			var stopLoss = candle.ClosePrice + StopLossPoints * priceStep;

			if (takeProfit >= candle.ClosePrice)
				return;

			SellMarket(TradeVolume);
			_targetPrice = takeProfit;
			_stopPrice = stopLoss;
			_positionDirection = -1;
		}
	}

	/// <summary>
	/// Closes the current position at market when the candle's range touches the tracked stop
	/// or target. If the strategy is flat but still tracks levels, the stale levels are cleared.
	/// </summary>
	private void ManageOpenPosition(ICandleMessage candle)
	{
		if (Position > 0m)
		{
			var stopHit = _stopPrice is decimal longStop && candle.LowPrice <= longStop;
			var targetHit = _targetPrice is decimal longTarget && candle.HighPrice >= longTarget;

			if (stopHit || targetHit)
			{
				SellMarket(Math.Abs(Position));
				ResetTargets();
			}
		}
		else if (Position < 0m)
		{
			var stopHit = _stopPrice is decimal shortStop && candle.HighPrice >= shortStop;
			var targetHit = _targetPrice is decimal shortTarget && candle.LowPrice <= shortTarget;

			if (stopHit || targetHit)
			{
				BuyMarket(Math.Abs(Position));
				ResetTargets();
			}
		}
		else if (_positionDirection != 0)
		{
			ResetTargets();
		}
	}

	/// <summary>
	/// Forgets the exit levels and direction of the tracked trade.
	/// </summary>
	private void ResetTargets()
	{
		_targetPrice = null;
		_stopPrice = null;
		_positionDirection = 0;
	}
}
import clr
clr.AddReference("StockSharp.Messages")
clr.AddReference("StockSharp.Algo")
clr.AddReference("StockSharp.Algo.Indicators")
clr.AddReference("StockSharp.Algo.Strategies")
from System import TimeSpan, Math
from StockSharp.Messages import DataType, CandleStates
from StockSharp.Algo.Indicators import RelativeStrengthIndex, MovingAverageConvergenceDivergenceSignal
from StockSharp.Algo.Strategies import Strategy
from collections import deque
from datatype_extensions import *
from indicator_extensions import *
import math
class neural_network_template_strategy(Strategy):
    """Momentum strategy inspired by the MQL5 neural-network template.

    The original network is replaced by a deterministic score: the current
    RSI's distance from 50 (weight 0.6) plus the tanh-squashed deviation of the
    MACD histogram from its rolling mean (weight 0.4).  Only one position is
    held at a time; exits use internally tracked take-profit / stop-loss prices
    checked against the high/low of every finished candle.
    """

    # Score weights; they sum to 1 so abs(score) stays within [0, 1].
    _RSI_WEIGHT = 0.6
    _MOMENTUM_WEIGHT = 0.4
    # Gain applied to the MACD deviation before tanh squashing.
    _MOMENTUM_SCALE = 5.0

    def __init__(self):
        super(neural_network_template_strategy, self).__init__()
        self._bars_to_pattern = self.Param("BarsToPattern", 3).SetGreaterThanZero().SetDisplay("Bars", "Candles analysed", "Model")
        self._max_tp_points = self.Param("MaxTakeProfitPoints", 500).SetGreaterThanZero().SetDisplay("Max TP", "Maximum take-profit in points", "Risk")
        self._min_target = self.Param("MinTargetPoints", 1).SetGreaterThanZero().SetDisplay("Min Target", "Minimum projected move in points", "Model")
        self._sl_points = self.Param("StopLossPoints", 300).SetGreaterThanZero().SetDisplay("Stop-Loss", "Stop-loss distance in points", "Risk")
        self._profit_mult = self.Param("ProfitMultiply", 0.8).SetNotNegative().SetDisplay("Profit Mult", "Take-profit multiplier", "Model")
        # Confidence lives in [0, 1]; the same range is enforced by the C# version.
        self._trade_level = self.Param("TradeLevel", 0.1).SetRange(0.0, 1.0).SetDisplay("Trade Level", "Required confidence", "Model")
        self._trade_volume = self.Param("TradeVolume", 0.1).SetGreaterThanZero().SetDisplay("Volume", "Order volume", "Trading")
        self._candle_type = self.Param("CandleType", tf(5)).SetDisplay("TF", "Working timeframe", "General")
        self._rsi = None
        self._macd = None
        # Make every state attribute exist even before OnReseted/OnStarted2 run.
        self._reset_state()

    @property
    def CandleType(self):
        return self._candle_type.Value

    @CandleType.setter
    def CandleType(self, value):
        self._candle_type.Value = value

    def _reset_state(self):
        """Clear the rolling windows and any tracked exit levels."""
        self._rsi_history = deque()
        self._macd_history = deque()
        # Running sum of _macd_history so the rolling mean is O(1) per candle.
        self._macd_sum = 0.0
        self._reset_targets()

    def OnReseted(self):
        super(neural_network_template_strategy, self).OnReseted()
        self._reset_state()

    def OnStarted2(self, time):
        super(neural_network_template_strategy, self).OnStarted2(time)
        self._reset_state()

        self._rsi = RelativeStrengthIndex()
        self._rsi.Length = 12
        self._macd = MovingAverageConvergenceDivergenceSignal()
        self._macd.Macd.ShortMa.Length = 12
        self._macd.Macd.LongMa.Length = 48
        self._macd.SignalMa.Length = 12

        sub = self.SubscribeCandles(self.CandleType)
        sub.BindEx(self._rsi, self._macd, self.OnProcess).Start()

        area = self.CreateChartArea()
        if area is not None:
            self.DrawCandles(area, sub)
            self.DrawIndicator(area, self._rsi)
            self.DrawIndicator(area, self._macd)
            self.DrawOwnTrades(area)

    def OnProcess(self, candle, rsi_value, macd_value):
        """Per-candle handler: manage exits, update windows, maybe enter."""
        if candle.State != CandleStates.Finished:
            return

        # Exits are checked before warm-up so an open trade is always managed.
        self._manage_position(candle)

        if not self._rsi.IsFormed or not self._macd.IsFormed:
            return

        rsi_val = float(rsi_value)
        macd_line = macd_value.Macd
        signal_line = macd_value.Signal
        if macd_line is None or signal_line is None:
            return
        macd_line = float(macd_line)
        signal_line = float(signal_line)

        self._update_history(rsi_val, macd_line - signal_line)

        # Single-position mode: only look for entries while flat.
        if self.Position != 0:
            return

        self._evaluate_entry(candle, rsi_val, macd_line, signal_line)

    def _update_history(self, rsi_val, macd_hist):
        """Append the latest values and trim both windows to BarsToPattern."""
        bars = self._bars_to_pattern.Value

        # `while` (not `if`) keeps the windows correct if BarsToPattern is
        # lowered while the strategy is running.
        self._rsi_history.append(rsi_val)
        while len(self._rsi_history) > bars:
            self._rsi_history.popleft()

        self._macd_history.append(macd_hist)
        self._macd_sum += macd_hist
        while len(self._macd_history) > bars:
            self._macd_sum -= self._macd_history.popleft()

    def _evaluate_entry(self, candle, rsi_val, macd_line, signal_line):
        """Score the candle and open a market position when the signal is strong enough."""
        bars = self._bars_to_pattern.Value
        # Wait until both windows are full; also guarantees a non-empty MACD window.
        if len(self._rsi_history) < bars or len(self._macd_history) < bars:
            return

        price_step = 1.0
        security = self.Security
        if security is not None and security.PriceStep is not None and security.PriceStep > 0:
            price_step = float(security.PriceStep)

        close = float(candle.ClosePrice)

        # RSI component: distance from the neutral 50 level mapped into [-1, 1].
        normalized_rsi = max(-1.0, min(1.0, (rsi_val - 50.0) / 50.0))

        # Momentum component: histogram deviation from its rolling mean (the
        # window already contains the current value), squashed by tanh.
        macd_avg = self._macd_sum / len(self._macd_history)
        macd_dev = (macd_line - signal_line) - macd_avg
        normalized_momentum = math.tanh(macd_dev * self._MOMENTUM_SCALE)

        combined = normalized_rsi * self._RSI_WEIGHT + normalized_momentum * self._MOMENTUM_WEIGHT
        confidence = min(1.0, abs(combined))

        # Projected move in price units and in price steps ("points").
        projected_move = macd_dev * bars
        projected_points = projected_move / price_step

        trade_level = self._trade_level.Value
        min_target = self._min_target.Value
        profit_mult = self._profit_mult.Value
        max_tp_distance = self._max_tp_points.Value * price_step
        sl_distance = self._sl_points.Value * price_step
        vol = self._trade_volume.Value

        if combined > 0:
            if confidence < trade_level or projected_points < min_target:
                return
            tp = close + min(projected_move * profit_mult, max_tp_distance)
            sl = close - sl_distance
            # A zero multiplier yields no profit distance; skip such trades.
            if tp <= close:
                return
            self.BuyMarket(vol)
            self._target_price = tp
            self._stop_price = sl
            self._pos_dir = 1
        elif combined < 0:
            if confidence < trade_level or projected_points > -min_target:
                return
            tp = close + max(projected_move * profit_mult, -max_tp_distance)
            sl = close + sl_distance
            if tp >= close:
                return
            self.SellMarket(vol)
            self._target_price = tp
            self._stop_price = sl
            self._pos_dir = -1

    def _manage_position(self, candle):
        """Close at market when the candle range touches the tracked stop or target."""
        position = self.Position
        if position > 0:
            low = float(candle.LowPrice)
            high = float(candle.HighPrice)
            stop_hit = self._stop_price is not None and low <= self._stop_price
            target_hit = self._target_price is not None and high >= self._target_price
            if stop_hit or target_hit:
                self.SellMarket(Math.Abs(position))
                self._reset_targets()
        elif position < 0:
            low = float(candle.LowPrice)
            high = float(candle.HighPrice)
            stop_hit = self._stop_price is not None and high >= self._stop_price
            target_hit = self._target_price is not None and low <= self._target_price
            if stop_hit or target_hit:
                self.BuyMarket(Math.Abs(position))
                self._reset_targets()
        elif self._pos_dir != 0:
            # Flat but still tracking levels: the trade ended elsewhere; forget it.
            self._reset_targets()

    def _reset_targets(self):
        """Forget the exit levels and direction of the tracked trade."""
        self._target_price = None
        self._stop_price = None
        # +1 after a long entry, -1 after a short entry, 0 when nothing is tracked.
        self._pos_dir = 0

    def CreateClone(self):
        return neural_network_template_strategy()