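RNN Probability Strategy
Overview
The RNN Probability strategy is a port of the MetaTrader expert advisor RNN (barabashkakvn's edition). The original algorithm samples three RSI values spaced one RSI period apart and feeds them into a hand-built probability network that imitates the decision logic of a recurrent neural network. The StockSharp version reproduces this flow through the high-level candle subscription API and automatically translates MetaTrader lot sizes, pip values and stop-loss/take-profit distances into their StockSharp counterparts.
Whenever a finished candle produces a new RSI value, the strategy looks back one and two RSI periods in the RSI history. The three values, normalized to [0; 1], are combined with eight weights (Weight0 … Weight7) to obtain the probability of a falling market. That probability is mapped linearly onto [-1; 1], and the sign of the result decides between going long and going short. Like the original EA, the strategy only ever holds a net position in one direction.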
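The combination and mapping steps fit in a few lines. The sketch below is a standalone restatement of CalculateProbability and the signal mapping, not the strategy's own code; the names rnn_probability, to_signal and DEFAULT_WEIGHTS and the sample RSI readings are hypothetical.

```python
def rnn_probability(p1: float, p2: float, p3: float, w: list[float]) -> float:
    """Blend the eight weights (0..100) using the normalized RSI inputs p1..p3 (0..1).

    p1 is the current RSI / 100, p2 the value one period ago, p3 two periods ago.
    """
    q1, q2, q3 = 1.0 - p1, 1.0 - p2, 1.0 - p3
    total = (
        q1 * (q2 * (q3 * w[0] + p3 * w[1]) + p2 * (q3 * w[2] + p3 * w[3]))
        + p1 * (q2 * (q3 * w[4] + p3 * w[5]) + p2 * (q3 * w[6] + p3 * w[7]))
    )
    return total / 100.0


def to_signal(probability: float) -> float:
    """Map the probability linearly from [0, 1] onto [-1, 1]."""
    return probability * 2.0 - 1.0


DEFAULT_WEIGHTS = [6, 96, 90, 35, 64, 83, 66, 50]

# Hypothetical readings: RSI 30 now, 55 one period ago, 70 two periods ago.
p = rnn_probability(0.30, 0.55, 0.70, DEFAULT_WEIGHTS)
s = to_signal(p)
print(f"probability={p:.4f} signal={s:.4f} -> {'buy' if s < 0 else 'sell'}")
```

With the default weights these inputs give a probability of about 0.61, so the signal is positive and the strategy would sell.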
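Trading logic
- Subscribes to the selected candle type and manually feeds the RelativeStrengthIndex indicator with the price chosen by AppliedPrice (the open price by default).
- Stores finished RSI values in a rolling buffer so that the values from one and two periods ago remain accessible.
- Divides the three RSI values by 100 to normalize them to [0; 1] and evaluates the probability network. Every weight is multiplied by a product of p or (1 - p) factors for the three inputs, so all eight weights always contribute:
  - Weight0, Weight1, Weight2 and Weight3 are scaled by (1 - p1), so they dominate when the current RSI is in the lower half of its range (below 50).
  - Weight4, Weight5, Weight6 and Weight7 are scaled by p1, so they dominate when the current RSI is in the upper half.
- Converts the resulting probability into a signal between -1 and +1 (signal = 2 × probability - 1).
- If the signal is negative and the strategy is flat or short, it buys TradeVolume plus the size of any short position, ending up long TradeVolume; if the signal is zero or positive and the strategy is flat or long, it sells TradeVolume plus the size of any long position, ending up short TradeVolume.
- Optionally places a stop-loss and a take-profit at the same pip distance. The strategy converts pips into an absolute price distance from PriceStep, including the ×10 multiplier MetaTrader applies on 3- and 5-digit quotes.
- The C# implementation writes every decision to the log, recording the RSI inputs, the probability and the final signal for later review.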
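The rolling buffer and the order rule from the list above can be sketched as follows. RsiSampler and order_for_signal are hypothetical helpers, not StockSharp APIs: the sampler keeps exactly 2 × RsiPeriod + 1 values (the C# code keeps a few more), and order_for_signal returns a side and volume instead of sending an order.

```python
from collections import deque
from typing import Optional, Tuple


class RsiSampler:
    """Hypothetical helper: stores finished RSI values (0..100) and returns the
    current, one-period-old and two-period-old readings scaled to [0, 1]."""

    def __init__(self, period: int):
        if period <= 0:
            raise ValueError("period must be positive")
        self.period = period
        # Reaching back two full periods needs 2 * period + 1 values.
        self.history = deque(maxlen=2 * period + 1)

    def add(self, rsi_value: float) -> Optional[Tuple[float, float, float]]:
        self.history.append(rsi_value)
        if len(self.history) < self.history.maxlen:
            return None  # not enough history yet
        # With a full buffer the oldest value sits exactly 2 * period bars back.
        p1 = self.history[-1] / 100.0
        p2 = self.history[-1 - self.period] / 100.0
        p3 = self.history[0] / 100.0
        return p1, p2, p3


def order_for_signal(signal: float, position: float,
                     trade_volume: float) -> Optional[Tuple[str, float]]:
    """Return (side, volume) following the strategy's rule, or None to do nothing."""
    if signal < 0:
        # Falling-market probability below 0.5: want long, closing any short first.
        return ("buy", abs(position) + trade_volume) if position <= 0 else None
    # Zero or positive signal: want short, closing any long first.
    return ("sell", position + trade_volume) if position >= 0 else None
```

For example, with a 1-lot short position and a negative signal, order_for_signal(-0.3, -1, 1) returns ("buy", 2), which flips the position to 1 lot long.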
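Parameters
| Name | Type | Default | Description |
|---|---|---|---|
| CandleType | DataType | 1-hour timeframe | Primary candle series used for signals and indicator data. |
| TradeVolume | decimal | 1 | Position size in lots opened on each entry. |
| RsiPeriod | int | 9 | RSI period; also sets the spacing between the three sampled RSI values. |
| AppliedPrice | AppliedPriceTypes | Open | Price fed to the RSI: Open, High, Low, Close, Median, Typical or Weighted. |
| StopLossTakeProfitPips | decimal | 100 | Stop-loss and take-profit distance in pips; set to 0 to disable the protective orders. |
| Weight0 … Weight7 | decimal | 6, 96, 90, 35, 64, 83, 66, 50 | The eight weights of the probability network, each in the range 0 to 100. |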
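The 0 to 100 range of the weights is what keeps the probability inside [0; 1]. A short derivation (not part of the original documentation) makes this visible: write the nested expression in CalculateProbability as a sum over the eight RSI combinations, where $b_i = 1$ means input $i$ is "high", the weight index is the binary number $b_1 b_2 b_3$, and $0^0 = 1$.

$$
P = \frac{1}{100}\sum_{b_1,b_2,b_3\in\{0,1\}} W_{4b_1+2b_2+b_3}\prod_{i=1}^{3} p_i^{\,b_i}(1-p_i)^{1-b_i},
\qquad S = 2P - 1
$$

$$
\sum_{b_1,b_2,b_3\in\{0,1\}}\ \prod_{i=1}^{3} p_i^{\,b_i}(1-p_i)^{1-b_i} = \prod_{i=1}^{3}\bigl(p_i + (1-p_i)\bigr) = 1
$$

The coefficients are nonnegative for $p_i \in [0, 1]$ and sum to 1, so $P$ is a weighted average of the $W_k/100$ and satisfies $\min_k W_k/100 \le P \le \max_k W_k/100$. With the default weights this gives $0.06 \le P \le 0.96$ and $-0.88 \le S \le 0.92$. At a neutral reading ($p_1 = p_2 = p_3 = 0.5$) every coefficient equals $1/8$, so $P = 490/800 = 0.6125$ and $S = 0.225$, which means the defaults lean toward selling.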
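Differences from the original MetaTrader expert
- E-mail notifications were removed; the StockSharp log provides the same feedback.
- The position size is fixed at TradeVolume; as in the original code, there is no partial closing or scaling in.
- Indicator input comes from the high-level candle subscription, so there is no need to call CopyBuffer or manage indicator handles.
- Pip conversion uses the instrument's PriceStep directly and multiplies it by 10 on 3- and 5-digit quotes instead of hard-coding the pip size.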
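The pip conversion mirrors CalculatePipSize. The sketch below is a hypothetical standalone version that counts decimal places with Python's decimal module instead of the multiply-by-10 loop; pip_size and protection_distance are illustrative names, not StockSharp APIs.

```python
from decimal import Decimal


def decimal_places(step: Decimal) -> int:
    """Digits after the decimal point, ignoring trailing zeros."""
    exponent = step.normalize().as_tuple().exponent
    return max(0, -exponent)


def pip_size(price_step: Decimal) -> Decimal:
    """MetaTrader-style pip: 10 price steps on 3- and 5-digit quotes, else one step."""
    if price_step <= 0:
        return Decimal(0)  # unknown step: protection cannot be computed
    if decimal_places(price_step) in (3, 5):
        return price_step * 10
    return price_step


def protection_distance(pips: Decimal, price_step: Decimal) -> Decimal:
    """Absolute stop-loss/take-profit distance; 0 pips disables protection."""
    return pips * pip_size(price_step)
```

With the default 100 pips on a 5-digit FX quote (PriceStep 0.00001) the protective orders sit 0.01 away from the entry price.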
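Usage tips
- Before starting, make sure TradeVolume matches the instrument's minimum volume step; the strategy also copies this value to Strategy.Volume when it starts.
- Tuning the eight weights during optimization can help the probability network adapt to different market conditions.
- On instruments with wide spreads, or when you plan to exit manually, reduce StopLossTakeProfitPips or set it to 0.
- Add the C# strategy to a chart to see candles, RSI and trades together, which makes it easier to verify the network's output.
Indicators
- One RelativeStrengthIndex calculated from the selected price.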
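For reference, the price selection performed in ProcessCandle before each RSI update corresponds to the formulas below. Candle, AppliedPrice and applied_price are hypothetical Python stand-ins for ICandleMessage and AppliedPriceTypes; the RSI calculation itself is left to the indicator.

```python
from enum import Enum
from typing import NamedTuple


class Candle(NamedTuple):
    open: float
    high: float
    low: float
    close: float


class AppliedPrice(Enum):
    OPEN = "open"
    HIGH = "high"
    LOW = "low"
    CLOSE = "close"
    MEDIAN = "median"
    TYPICAL = "typical"
    WEIGHTED = "weighted"


# Same formulas as the switch expression in the C# ProcessCandle.
_FORMULAS = {
    AppliedPrice.OPEN: lambda c: c.open,
    AppliedPrice.HIGH: lambda c: c.high,
    AppliedPrice.LOW: lambda c: c.low,
    AppliedPrice.CLOSE: lambda c: c.close,
    AppliedPrice.MEDIAN: lambda c: (c.high + c.low) / 2,
    AppliedPrice.TYPICAL: lambda c: (c.high + c.low + c.close) / 3,
    AppliedPrice.WEIGHTED: lambda c: (c.high + c.low + 2 * c.close) / 4,
}


def applied_price(candle: Candle, kind: AppliedPrice) -> float:
    """Price fed to the RSI for one finished candle."""
    return _FORMULAS[kind](candle)
```

The Python port in this document has no AppliedPrice parameter and always feeds the open price, which matches the default.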
```csharp
namespace StockSharp.Samples.Strategies;

using System;
using System.Linq;
using System.Collections.Generic;

using Ecng.Common;
using Ecng.Collections;
using Ecng.Serialization;

using StockSharp.Algo.Indicators;
using StockSharp.Algo.Strategies;
using StockSharp.BusinessEntities;
using StockSharp.Messages;
using StockSharp.Algo;

/// <summary>
/// Probabilistic strategy converted from the RNN MetaTrader expert.
/// It feeds three delayed RSI readings into the original probability lattice and
/// trades in the direction suggested by the neural network output.
/// </summary>
public class RnnProbabilityStrategy : Strategy
{
    /// <summary>
    /// Price sources that can be forwarded to the RSI indicator.
    /// </summary>
    public enum AppliedPriceTypes
    {
        Open,
        High,
        Low,
        Close,
        Median,
        Typical,
        Weighted
    }

    private readonly StrategyParam<decimal> _tradeVolume;
    private readonly StrategyParam<int> _rsiPeriod;
    private readonly StrategyParam<AppliedPriceTypes> _appliedPrice;
    private readonly StrategyParam<decimal> _stopLossTakeProfitPips;
    private readonly StrategyParam<decimal> _weight0;
    private readonly StrategyParam<decimal> _weight1;
    private readonly StrategyParam<decimal> _weight2;
    private readonly StrategyParam<decimal> _weight3;
    private readonly StrategyParam<decimal> _weight4;
    private readonly StrategyParam<decimal> _weight5;
    private readonly StrategyParam<decimal> _weight6;
    private readonly StrategyParam<decimal> _weight7;
    private readonly StrategyParam<DataType> _candleType;

    private RelativeStrengthIndex _rsi;
    private readonly List<decimal> _rsiHistory = new();
    private decimal _pipSize;

    /// <summary>
    /// Trade volume expressed in lots.
    /// </summary>
    public decimal TradeVolume
    {
        get => _tradeVolume.Value;
        set => _tradeVolume.Value = value;
    }

    /// <summary>
    /// Averaging period for the RSI indicator.
    /// </summary>
    public int RsiPeriod
    {
        get => _rsiPeriod.Value;
        set => _rsiPeriod.Value = value;
    }

    /// <summary>
    /// Price source forwarded to the RSI indicator.
    /// </summary>
    public AppliedPriceTypes AppliedPrice
    {
        get => _appliedPrice.Value;
        set => _appliedPrice.Value = value;
    }

    /// <summary>
    /// Symmetric stop-loss and take-profit distance expressed in pips.
    /// </summary>
    public decimal StopLossTakeProfitPips
    {
        get => _stopLossTakeProfitPips.Value;
        set => _stopLossTakeProfitPips.Value = value;
    }

    /// <summary>
    /// Neural network weight for the (low, low, low) RSI combination.
    /// </summary>
    public decimal Weight0
    {
        get => _weight0.Value;
        set => _weight0.Value = value;
    }

    /// <summary>
    /// Neural network weight for the (low, low, high) RSI combination.
    /// </summary>
    public decimal Weight1
    {
        get => _weight1.Value;
        set => _weight1.Value = value;
    }

    /// <summary>
    /// Neural network weight for the (low, high, low) RSI combination.
    /// </summary>
    public decimal Weight2
    {
        get => _weight2.Value;
        set => _weight2.Value = value;
    }

    /// <summary>
    /// Neural network weight for the (low, high, high) RSI combination.
    /// </summary>
    public decimal Weight3
    {
        get => _weight3.Value;
        set => _weight3.Value = value;
    }

    /// <summary>
    /// Neural network weight for the (high, low, low) RSI combination.
    /// </summary>
    public decimal Weight4
    {
        get => _weight4.Value;
        set => _weight4.Value = value;
    }

    /// <summary>
    /// Neural network weight for the (high, low, high) RSI combination.
    /// </summary>
    public decimal Weight5
    {
        get => _weight5.Value;
        set => _weight5.Value = value;
    }

    /// <summary>
    /// Neural network weight for the (high, high, low) RSI combination.
    /// </summary>
    public decimal Weight6
    {
        get => _weight6.Value;
        set => _weight6.Value = value;
    }

    /// <summary>
    /// Neural network weight for the (high, high, high) RSI combination.
    /// </summary>
    public decimal Weight7
    {
        get => _weight7.Value;
        set => _weight7.Value = value;
    }

    /// <summary>
    /// Candle series used for indicator calculations and trading decisions.
    /// </summary>
    public DataType CandleType
    {
        get => _candleType.Value;
        set => _candleType.Value = value;
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="RnnProbabilityStrategy"/> class.
    /// </summary>
    public RnnProbabilityStrategy()
    {
        _tradeVolume = Param(nameof(TradeVolume), 1m)
            .SetDisplay("Trade Volume", "Lot size used for each market entry.", "General")
            .SetGreaterThanZero();

        _rsiPeriod = Param(nameof(RsiPeriod), 9)
            .SetDisplay("RSI Period", "Length of the RSI indicator feeding the neural network.", "Indicator")
            .SetRange(2, 200);

        _appliedPrice = Param(nameof(AppliedPrice), AppliedPriceTypes.Open)
            .SetDisplay("Applied Price", "Price type forwarded to the RSI indicator.", "Indicator");

        _stopLossTakeProfitPips = Param(nameof(StopLossTakeProfitPips), 100m)
            .SetDisplay("Stop Loss & Take Profit (pips)", "Distance used for both stop-loss and take-profit levels.", "Risk")
            .SetRange(0m, 1000m);

        _weight0 = Param(nameof(Weight0), 6m)
            .SetDisplay("Weight 0", "Probability weight applied when all RSI inputs are low.", "Model")
            .SetRange(0m, 100m);

        _weight1 = Param(nameof(Weight1), 96m)
            .SetDisplay("Weight 1", "Probability weight for the (low, low, high) branch.", "Model")
            .SetRange(0m, 100m);

        _weight2 = Param(nameof(Weight2), 90m)
            .SetDisplay("Weight 2", "Probability weight for the (low, high, low) branch.", "Model")
            .SetRange(0m, 100m);

        _weight3 = Param(nameof(Weight3), 35m)
            .SetDisplay("Weight 3", "Probability weight for the (low, high, high) branch.", "Model")
            .SetRange(0m, 100m);

        _weight4 = Param(nameof(Weight4), 64m)
            .SetDisplay("Weight 4", "Probability weight for the (high, low, low) branch.", "Model")
            .SetRange(0m, 100m);

        _weight5 = Param(nameof(Weight5), 83m)
            .SetDisplay("Weight 5", "Probability weight for the (high, low, high) branch.", "Model")
            .SetRange(0m, 100m);

        _weight6 = Param(nameof(Weight6), 66m)
            .SetDisplay("Weight 6", "Probability weight for the (high, high, low) branch.", "Model")
            .SetRange(0m, 100m);

        _weight7 = Param(nameof(Weight7), 50m)
            .SetDisplay("Weight 7", "Probability weight for the (high, high, high) branch.", "Model")
            .SetRange(0m, 100m);

        _candleType = Param(nameof(CandleType), TimeSpan.FromHours(1).TimeFrame())
            .SetDisplay("Candle Type", "Primary timeframe used for signal generation.", "General");
    }

    /// <inheritdoc />
    public override IEnumerable<(Security sec, DataType dt)> GetWorkingSecurities()
    {
        return [(Security, CandleType)];
    }

    /// <inheritdoc />
    protected override void OnReseted()
    {
        base.OnReseted();

        _rsi = default;
        _rsiHistory.Clear();
        _pipSize = 0m;
    }

    /// <inheritdoc />
    protected override void OnStarted2(DateTime time)
    {
        base.OnStarted2(time);

        Volume = TradeVolume;
        _pipSize = CalculatePipSize();

        // Symmetric stop-loss/take-profit in absolute price units; 0 pips disables protection.
        if (StopLossTakeProfitPips > 0m && _pipSize > 0m)
        {
            var distance = StopLossTakeProfitPips * _pipSize;

            StartProtection(
                takeProfit: new Unit(distance, UnitTypes.Absolute),
                stopLoss: new Unit(distance, UnitTypes.Absolute),
                isStopTrailing: false,
                useMarketOrders: true);
        }

        _rsi = new RelativeStrengthIndex
        {
            Length = RsiPeriod
        };

        // The RSI is processed manually in ProcessCandle so that AppliedPrice can be honored.
        var subscription = SubscribeCandles(CandleType);
        subscription.Bind(ProcessCandle).Start();

        var area = CreateChartArea();
        if (area != null)
        {
            DrawCandles(area, subscription);
            DrawIndicator(area, _rsi);
            DrawOwnTrades(area);
        }
    }

    private void ProcessCandle(ICandleMessage candle)
    {
        if (candle.State != CandleStates.Finished)
            return;

        if (_rsi == null)
            return;

        if (RsiPeriod <= 0)
            return;

        var price = AppliedPrice switch
        {
            AppliedPriceTypes.Open => candle.OpenPrice,
            AppliedPriceTypes.High => candle.HighPrice,
            AppliedPriceTypes.Low => candle.LowPrice,
            AppliedPriceTypes.Close => candle.ClosePrice,
            AppliedPriceTypes.Median => (candle.HighPrice + candle.LowPrice) / 2m,
            AppliedPriceTypes.Typical => (candle.HighPrice + candle.LowPrice + candle.ClosePrice) / 3m,
            AppliedPriceTypes.Weighted => (candle.HighPrice + candle.LowPrice + 2m * candle.ClosePrice) / 4m,
            _ => candle.ClosePrice,
        };

        var rsiIndicatorValue = _rsi.Process(new DecimalIndicatorValue(_rsi, price, candle.OpenTime) { IsFinal = true });

        if (!_rsi.IsFormed || rsiIndicatorValue.IsEmpty)
            return;

        // Keep a rolling window long enough to reach back two RSI periods.
        var rsiValue = rsiIndicatorValue.ToDecimal();
        _rsiHistory.Add(rsiValue);
        TrimHistory(_rsiHistory, GetHistoryLimit());

        var lastIndex = _rsiHistory.Count - 1;
        var delayedIndex = lastIndex - RsiPeriod;
        var delayedTwiceIndex = lastIndex - (2 * RsiPeriod);

        if (delayedIndex < 0 || delayedTwiceIndex < 0)
            return;

        // Normalize the three RSI readings to [0; 1].
        var p1 = _rsiHistory[lastIndex] / 100m;
        var p2 = _rsiHistory[delayedIndex] / 100m;
        var p3 = _rsiHistory[delayedTwiceIndex] / 100m;

        var probability = CalculateProbability(p1, p2, p3);
        var signal = probability * 2m - 1m;

        LogInfo($"RSI inputs: p1={p1:F4}, p2={p2:F4}, p3={p3:F4}, probability={probability:F4}, signal={signal:F4}");

        if (TradeVolume <= 0m)
            return;

        if (signal < 0m)
        {
            // Negative signal: go long, reversing any short position.
            if (Position <= 0m)
            {
                var vol = Math.Abs(Position) + TradeVolume;
                BuyMarket(vol);
            }
        }
        else
        {
            // Zero or positive signal: go short, reversing any long position.
            if (Position >= 0m)
            {
                var vol = Position + TradeVolume;
                SellMarket(vol);
            }
        }
    }

    private decimal CalculateProbability(decimal p1, decimal p2, decimal p3)
    {
        var pn1 = 1m - p1;
        var pn2 = 1m - p2;
        var pn3 = 1m - p3;

        // Each weight is scaled by the product of p or (1 - p) for the three inputs.
        var probability =
            pn1 * (pn2 * (pn3 * Weight0 + p3 * Weight1) +
                   p2 * (pn3 * Weight2 + p3 * Weight3)) +
            p1 * (pn2 * (pn3 * Weight4 + p3 * Weight5) +
                  p2 * (pn3 * Weight6 + p3 * Weight7));

        return probability / 100m;
    }

    private int GetHistoryLimit()
    {
        return Math.Max((2 * RsiPeriod) + 5, RsiPeriod + 1);
    }

    private static void TrimHistory<T>(List<T> source, int maxSize)
    {
        if (maxSize <= 0)
            return;

        if (source.Count <= maxSize)
            return;

        var removeCount = source.Count - maxSize;
        source.RemoveRange(0, removeCount);
    }

    private decimal CalculatePipSize()
    {
        if (Security == null)
            return 0m;

        var step = Security.PriceStep ?? 0m;
        if (step <= 0m)
            return 0m;

        // MetaTrader pips are ten price steps on 3- and 5-digit quotes.
        var decimals = GetDecimalPlaces(step);
        if (decimals == 3 || decimals == 5)
            return step * 10m;

        return step;
    }

    private static int GetDecimalPlaces(decimal value)
    {
        value = Math.Abs(value);
        var decimals = 0;

        while (value != Math.Truncate(value) && decimals < 10)
        {
            value *= 10m;
            decimals++;
        }

        return decimals;
    }
}
```
```python
import clr

clr.AddReference("StockSharp.Messages")
clr.AddReference("StockSharp.BusinessEntities")
clr.AddReference("StockSharp.Algo")
clr.AddReference("StockSharp.Algo.Indicators")
clr.AddReference("StockSharp.Algo.Strategies")

from decimal import Decimal

from System import TimeSpan, Math
from StockSharp.Messages import DataType, CandleStates, Unit, UnitTypes
from StockSharp.Algo.Indicators import RelativeStrengthIndex
from StockSharp.Algo.Strategies import Strategy
from indicator_extensions import *


class rnn_probability_strategy(Strategy):
    def __init__(self):
        super(rnn_probability_strategy, self).__init__()

        self._trade_volume = self.Param("TradeVolume", 1.0) \
            .SetDisplay("Trade Volume", "Lot size used for each market entry.", "General")
        self._rsi_period = self.Param("RsiPeriod", 9) \
            .SetDisplay("RSI Period", "Length of the RSI indicator feeding the neural network.", "Indicator")
        self._stop_loss_take_profit_pips = self.Param("StopLossTakeProfitPips", 100.0) \
            .SetDisplay("Stop Loss & Take Profit (pips)", "Distance used for both stop-loss and take-profit levels.", "Risk")
        self._weight0 = self.Param("Weight0", 6.0) \
            .SetDisplay("Weight 0", "Probability weight applied when all RSI inputs are low.", "Model")
        self._weight1 = self.Param("Weight1", 96.0) \
            .SetDisplay("Weight 1", "Probability weight for the (low, low, high) branch.", "Model")
        self._weight2 = self.Param("Weight2", 90.0) \
            .SetDisplay("Weight 2", "Probability weight for the (low, high, low) branch.", "Model")
        self._weight3 = self.Param("Weight3", 35.0) \
            .SetDisplay("Weight 3", "Probability weight for the (low, high, high) branch.", "Model")
        self._weight4 = self.Param("Weight4", 64.0) \
            .SetDisplay("Weight 4", "Probability weight for the (high, low, low) branch.", "Model")
        self._weight5 = self.Param("Weight5", 83.0) \
            .SetDisplay("Weight 5", "Probability weight for the (high, low, high) branch.", "Model")
        self._weight6 = self.Param("Weight6", 66.0) \
            .SetDisplay("Weight 6", "Probability weight for the (high, high, low) branch.", "Model")
        self._weight7 = self.Param("Weight7", 50.0) \
            .SetDisplay("Weight 7", "Probability weight for the (high, high, high) branch.", "Model")
        self._candle_type = self.Param("CandleType", DataType.TimeFrame(TimeSpan.FromHours(1))) \
            .SetDisplay("Candle Type", "Primary timeframe used for signal generation.", "General")

        self._rsi = None
        self._rsi_history = []
        self._pip_size = 0.0

    @property
    def trade_volume(self):
        return self._trade_volume.Value

    @property
    def rsi_period(self):
        return self._rsi_period.Value

    @property
    def stop_loss_take_profit_pips(self):
        return self._stop_loss_take_profit_pips.Value

    @property
    def weight0(self):
        return self._weight0.Value

    @property
    def weight1(self):
        return self._weight1.Value

    @property
    def weight2(self):
        return self._weight2.Value

    @property
    def weight3(self):
        return self._weight3.Value

    @property
    def weight4(self):
        return self._weight4.Value

    @property
    def weight5(self):
        return self._weight5.Value

    @property
    def weight6(self):
        return self._weight6.Value

    @property
    def weight7(self):
        return self._weight7.Value

    @property
    def candle_type(self):
        return self._candle_type.Value

    def OnReseted(self):
        super(rnn_probability_strategy, self).OnReseted()
        self._rsi = None
        self._rsi_history = []
        self._pip_size = 0.0

    def OnStarted2(self, time):
        super(rnn_probability_strategy, self).OnStarted2(time)

        self.Volume = self.trade_volume
        self._pip_size = self._calculate_pip_size()

        # Symmetric stop-loss/take-profit in absolute price units; 0 pips disables protection.
        sl_tp_pips = float(self.stop_loss_take_profit_pips)
        if sl_tp_pips > 0 and self._pip_size > 0:
            distance = Unit(sl_tp_pips * self._pip_size, UnitTypes.Absolute)
            self.StartProtection(
                takeProfit=distance,
                stopLoss=distance,
                isStopTrailing=False,
                useMarketOrders=True)

        self._rsi = RelativeStrengthIndex()
        self._rsi.Length = self.rsi_period

        subscription = self.SubscribeCandles(self.candle_type)
        subscription.Bind(self._process_candle).Start()

    def _process_candle(self, candle):
        if candle.State != CandleStates.Finished:
            return
        if self._rsi is None:
            return

        rsi_period = self.rsi_period
        if rsi_period <= 0:
            return

        # The Python port has no AppliedPrice parameter: the RSI is always fed the open price.
        price = float(candle.OpenPrice)
        rsi_ind_value = process_float(self._rsi, price, candle.OpenTime, True)
        if not self._rsi.IsFormed or rsi_ind_value.IsEmpty:
            return

        # Keep a rolling window long enough to reach back two RSI periods.
        rsi_value = float(rsi_ind_value.Value)
        self._rsi_history.append(rsi_value)
        max_size = max(2 * rsi_period + 5, rsi_period + 1)
        if len(self._rsi_history) > max_size:
            self._rsi_history = self._rsi_history[-max_size:]

        last_index = len(self._rsi_history) - 1
        delayed_index = last_index - rsi_period
        delayed_twice_index = last_index - 2 * rsi_period
        if delayed_index < 0 or delayed_twice_index < 0:
            return

        # Normalize the three RSI readings to [0, 1].
        p1 = self._rsi_history[last_index] / 100.0
        p2 = self._rsi_history[delayed_index] / 100.0
        p3 = self._rsi_history[delayed_twice_index] / 100.0

        probability = self._calculate_probability(p1, p2, p3)
        signal = probability * 2.0 - 1.0

        tv = float(self.trade_volume)
        if tv <= 0:
            return

        pos = float(self.Position)
        if signal < 0:
            # Negative signal: go long, reversing any short position.
            if pos <= 0:
                vol = abs(pos) + tv
                self.BuyMarket(vol)
        else:
            # Zero or positive signal: go short, reversing any long position.
            if pos >= 0:
                vol = pos + tv
                self.SellMarket(vol)

    def _calculate_probability(self, p1, p2, p3):
        pn1 = 1.0 - p1
        pn2 = 1.0 - p2
        pn3 = 1.0 - p3

        w0 = float(self.weight0)
        w1 = float(self.weight1)
        w2 = float(self.weight2)
        w3 = float(self.weight3)
        w4 = float(self.weight4)
        w5 = float(self.weight5)
        w6 = float(self.weight6)
        w7 = float(self.weight7)

        # Each weight is scaled by the product of p or (1 - p) for the three inputs.
        probability = (
            pn1 * (pn2 * (pn3 * w0 + p3 * w1) +
                   p2 * (pn3 * w2 + p3 * w3)) +
            p1 * (pn2 * (pn3 * w4 + p3 * w5) +
                  p2 * (pn3 * w6 + p3 * w7))
        )
        return probability / 100.0

    def _calculate_pip_size(self):
        sec = self.Security
        if sec is None:
            return 0.0

        step = sec.PriceStep
        if step is None or float(step) <= 0:
            return 0.0

        # MetaTrader pips are ten price steps on 3- and 5-digit quotes.
        step_val = float(step)
        decimals = self._get_decimal_places(step_val)
        if decimals == 3 or decimals == 5:
            return step_val * 10.0
        return step_val

    def _get_decimal_places(self, value):
        # Count decimals with Decimal rather than repeated float multiplication,
        # which is inexact and may never land on an integer (misdetecting 5-digit quotes).
        exponent = Decimal(repr(abs(float(value)))).normalize().as_tuple().exponent
        return max(0, -exponent)

    def CreateClone(self):
        return rnn_probability_strategy()
```